Source code for memorax.networks.blocks.segment_recurrence
from typing import Any
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax import struct
from memorax.networks.sequence_models.sequence_model import SequenceModel
from memorax.utils.axes import get_input_shape
from memorax.utils.typing import Array, Carry, Key
from .base import Block
@struct.dataclass
class Memory:
    """Rolling buffer of past outputs carried from one segment to the next.

    Attributes:
        state: Stored past outputs, one entry per retained timestep.
        done: Per-timestep ``done`` flags aligned with the entries of ``state``.
    """

    # Past (gradient-stopped) outputs kept for later segments.
    state: Array
    # ``done`` flags for each stored timestep.
    done: Array
[docs]
class SegmentRecurrence(nn.Module, Block):
    """Wraps a sequence model with segment-level recurrence memory.

    This block maintains a fixed-length memory of past outputs. The memory is
    passed to the wrapped sequence model through the ``memory`` and
    ``memory_done`` keyword arguments, e.g. for cross-segment attention.

    Args:
        module: The underlying sequence model to wrap.
        memory_length: Maximum number of past timesteps to retain. ``0`` keeps
            an empty memory.
        features: Feature dimension of the memory.
        dtype: Data type for memory storage (``None`` uses the JAX default).
    """

    module: SequenceModel
    memory_length: int
    features: int
    dtype: Any = None

    @nn.nowrap
    def initialize_carry(self, key: Key, input_shape: tuple) -> Carry:
        """Build the initial ``(memory, inner_carry)`` pair.

        Args:
            key: PRNG key forwarded to the wrapped module's ``initialize_carry``.
            input_shape: Shape whose last entry is dropped; the remaining
                leading entries become the memory's batch dimensions.

        Returns:
            ``(Memory, inner_carry)``. The memory is zero-filled:
            ``state`` has shape ``(*batch_dims, memory_length, features)`` and
            ``done`` has shape ``(*batch_dims, memory_length)`` (int32).
        """
        *batch_dims, _ = input_shape
        state = jnp.zeros(
            (*batch_dims, self.memory_length, self.features), dtype=self.dtype
        )
        done = jnp.zeros((*batch_dims, self.memory_length), dtype=jnp.int32)
        memory = Memory(state=state, done=done)
        carry = self.module.initialize_carry(key, input_shape)
        return (memory, carry)

    @nn.compact
    def __call__(
        self,
        inputs: Array,
        done: Array | None = None,
        initial_carry: Carry | None = None,
        **kwargs,
    ) -> tuple[Carry, Array]:
        """Run the wrapped module on one segment and roll the memory forward.

        Args:
            inputs: Segment inputs; the first two axes are unpacked as
                ``(batch, seq_len)`` when ``done`` must be defaulted.
            done: Per-step done flags of shape ``(batch, seq_len)``. Defaults
                to all zeros.
            initial_carry: ``(Memory, inner_carry)`` from a previous call or
                from ``initialize_carry``. When ``None``, a fresh carry is built
                with a fixed PRNG key.
            **kwargs: Forwarded unchanged to the wrapped module.

        Returns:
            ``((memory, inner_carry), y)``. ``memory`` holds the last
            ``memory_length`` steps of the previous memory followed by this
            segment's outputs ``y``.
        """
        if initial_carry is None:
            input_shape = get_input_shape(inputs)
            initial_carry = self.initialize_carry(jax.random.key(0), input_shape)
        if done is None:
            batch_size, seq_len, *_ = inputs.shape
            done = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)

        memory, carry = initial_carry
        carry, y = self.module(
            inputs,
            done,
            initial_carry=carry,
            memory=memory.state,
            memory_done=memory.done,
            **kwargs,
        )

        def keep_recent(x: Array) -> Array:
            # Keep the last ``memory_length`` steps along the time axis (1).
            # An explicit, clamped start index avoids the ``x[:, -0:]`` pitfall,
            # which would return the whole array when memory_length == 0.
            start = max(x.shape[1] - self.memory_length, 0)
            return x[:, start:]

        # Stop gradients so past segments are treated as constants. Cast to the
        # memory's own dtypes so dtype promotion cannot change the carry type
        # between calls (carry types must stay fixed, e.g. under lax.scan).
        new_state = jnp.concatenate(
            [memory.state, jax.lax.stop_gradient(y).astype(memory.state.dtype)],
            axis=1,
        )
        new_done = jnp.concatenate(
            [memory.done, jnp.asarray(done, dtype=memory.done.dtype)], axis=1
        )
        memory = Memory(state=keep_recent(new_state), done=keep_recent(new_done))
        return (memory, carry), y