Source code for memorax.networks.blocks.segment_recurrence

from typing import Any

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax import struct

from memorax.networks.sequence_models.sequence_model import SequenceModel
from memorax.utils.axes import get_input_shape
from memorax.utils.typing import Array, Carry, Key

from .base import Block


@struct.dataclass
class Memory:
    """Rolling buffer of past segment outputs carried by ``SegmentRecurrence``."""

    # Retained outputs of the wrapped module; created by
    # ``SegmentRecurrence.initialize_carry`` as zeros of shape
    # (*batch_dims, memory_length, features).
    state: Array
    # Done flags kept alongside ``state``; created as int32 zeros of shape
    # (*batch_dims, memory_length).
    done: Array


class SegmentRecurrence(nn.Module, Block):
    """Wraps a sequence model with segment-level recurrence memory.

    This block maintains a fixed-length memory of past outputs that can be
    used by the wrapped sequence model for cross-segment attention.

    Args:
        module: The underlying sequence model to wrap.
        memory_length: Maximum number of past timesteps to retain.
        features: Feature dimension of the memory.
        dtype: Data type for memory storage.
    """

    module: SequenceModel
    memory_length: int
    features: int
    dtype: Any = None

    @nn.nowrap
    def initialize_carry(self, key: Key, input_shape: tuple) -> Carry:
        """Build the initial ``(memory, carry)`` pair.

        Args:
            key: PRNG key forwarded to the wrapped module's
                ``initialize_carry``.
            input_shape: Shape tuple; its last entry is dropped and the
                remaining leading entries become the batch dimensions of the
                memory buffers. It is also forwarded unchanged to the wrapped
                module.

        Returns:
            A tuple ``(memory, carry)``. ``memory`` is a zero-filled
            ``Memory`` with ``state`` of shape
            ``(*batch_dims, memory_length, features)`` in ``self.dtype`` and
            an int32 ``done`` of shape ``(*batch_dims, memory_length)``.
            ``carry`` is whatever the wrapped module returns.
        """
        *batch_dims, _ = input_shape
        state = jnp.zeros(
            (*batch_dims, self.memory_length, self.features), dtype=self.dtype
        )
        done = jnp.zeros((*batch_dims, self.memory_length), dtype=jnp.int32)
        memory = Memory(state=state, done=done)
        carry = self.module.initialize_carry(key, input_shape)
        return (memory, carry)

    @nn.compact
    def __call__(
        self,
        inputs: Array,
        done: Array | None = None,
        initial_carry: Carry | None = None,
        **kwargs,
    ) -> tuple[Carry, Array]:
        """Run the wrapped module on one segment and roll the memory forward.

        Args:
            inputs: Segment inputs; the first two axes are read as
                ``(batch, time)`` when ``done`` has to be synthesized.
            done: Optional per-timestep done flags. Defaults to int32 zeros
                of shape ``(batch, time)``.
            initial_carry: Optional ``(memory, carry)`` pair as produced by
                ``initialize_carry``. When omitted, a fresh one is built with
                ``jax.random.key(0)``.
            **kwargs: Forwarded unchanged to the wrapped module.

        Returns:
            ``((memory, carry), y)`` where ``y`` is the wrapped module's
            output and ``memory`` holds the most recent ``memory_length``
            timesteps of previous memory followed by the new outputs.
        """
        if initial_carry is None:
            input_shape = get_input_shape(inputs)
            initial_carry = self.initialize_carry(jax.random.key(0), input_shape)

        if done is None:
            batch_size, seq_len, *_ = inputs.shape
            done = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)

        memory, carry = initial_carry

        carry, y = self.module(
            inputs,
            done,
            initial_carry=carry,
            memory=memory.state,
            memory_done=memory.done,
            **kwargs,
        )

        def keep_last(buffer: Array) -> Array:
            # An explicit start index is required: ``buffer[:, -0:]`` would
            # return the whole buffer, so ``memory_length == 0`` would grow
            # the memory instead of keeping it empty.
            start = max(buffer.shape[1] - self.memory_length, 0)
            return buffer[:, start:]

        # Gradients do not flow into stored outputs. New entries are cast to
        # the existing buffer dtypes so the returned carry keeps the same
        # dtypes as the incoming one rather than being silently promoted.
        new_state = jax.lax.stop_gradient(y).astype(memory.state.dtype)
        new_done = done.astype(memory.done.dtype)

        state = keep_last(jnp.concatenate([memory.state, new_state], axis=1))
        done = keep_last(jnp.concatenate([memory.done, new_done], axis=1))

        memory = Memory(state=state, done=done)

        return (memory, carry), y