Source code for memorax.networks.sequence_models.wrappers
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax import struct
from memorax.utils.typing import Array, Carry, Key
from .sequence_model import SequenceModel
[docs]
class SequenceModelWrapper(SequenceModel, nn.Module):
    """Adapt a stateless network to the ``SequenceModel`` calling convention.

    The wrapped ``network`` is applied to ``inputs`` directly. No recurrent
    state is kept: the carry handed in is handed back unchanged, and the
    ``done`` flags are not used.
    """

    network: nn.Module

    def __call__(
        self,
        inputs: Array,
        done: Array,
        initial_carry: Carry | None = None,
        **kwargs,
    ) -> tuple[Carry, Array]:
        """Apply ``network`` to ``inputs``.

        Args:
            inputs: Passed straight to ``network``.
            done: Unused; accepted only to match the interface.
            initial_carry: Returned as-is (may be ``None``).
            **kwargs: Forwarded to ``network``.

        Returns:
            ``(initial_carry, network(inputs, **kwargs))``.
        """
        outputs = self.network(inputs, **kwargs)
        return initial_carry, outputs

    def initialize_carry(self, key: Key, input_shape: tuple) -> None:
        """Return ``None``; a stateless network has no carry to build."""
        del key, input_shape  # unused; present only to match the interface
        return None
@struct.dataclass
class RL2State:
    """Carry used by ``RL2Wrapper``: the inner model's state plus a step counter.

    Attributes:
        carry: Recurrent state of the wrapped sequence model, in whatever
            structure that model's ``initialize_carry`` returns (not
            necessarily a single array).
        step: Per-batch-element count of steps processed so far; created
            by ``RL2Wrapper.initialize_carry`` as an int32 array of shape
            ``(batch_size,)``.
    """

    carry: Carry
    step: Array
[docs]
class RL2Wrapper(SequenceModel, nn.Module):
    """Wrap a sequence model so its state is reset only at trial boundaries.

    The caller's ``done`` flags are discarded. Instead, reset flags are
    derived from a per-batch step counter kept in the carry: a reset is
    signalled on every step whose absolute index is a multiple of
    ``steps_per_trial``.

    Attributes:
        sequence_model: The wrapped model. It is called as
            ``sequence_model(inputs, done, carry, **kwargs)`` and must provide
            ``initialize_carry(key, (batch_size, features))``.
        steps_per_trial: Length of one trial in steps; must be positive.
    """

    sequence_model: nn.Module
    steps_per_trial: int

    def __call__(
        self,
        inputs: Array,
        done: Array,
        initial_carry: Carry | None = None,
        **kwargs,
    ) -> tuple[RL2State, Array]:
        """Run the wrapped model, resetting its state at trial boundaries.

        Args:
            inputs: Input sequence. Axis 0 is read as the batch (as in
                ``initialize_carry``) and axis 1 as the sequence length.
            done: Ignored. It is replaced by trial-boundary flags computed
                from the step counter, so the inner state survives ``done``
                signals that fall inside a trial.
            initial_carry: An ``RL2State``. If ``None``, a fresh one is
                built via ``initialize_carry`` with a fixed PRNG key.
            **kwargs: Forwarded to ``sequence_model``.

        Returns:
            ``(new_state, outputs)``; ``new_state.step`` is advanced by the
            sequence length.

        Raises:
            ValueError: If ``steps_per_trial`` is not positive.
        """
        if self.steps_per_trial <= 0:
            raise ValueError(
                f"steps_per_trial must be positive, got {self.steps_per_trial}"
            )
        # The caller's flags are intentionally discarded; see docstring.
        del done

        _, sequence_length, *_ = inputs.shape
        if initial_carry is None:
            initial_carry = self.initialize_carry(jax.random.key(0), inputs.shape)

        # Absolute step index for every (batch, time) position.
        time_indices = jnp.arange(sequence_length)
        steps = initial_carry.step[:, None] + time_indices[None, :]
        # Reset on the first step of each trial: steps 0, N, 2N, ...
        trial_start = steps % self.steps_per_trial == 0

        carry, outputs = self.sequence_model(
            inputs, trial_start, initial_carry.carry, **kwargs
        )
        new_state = RL2State(
            carry=carry, step=initial_carry.step + sequence_length
        )
        return new_state, outputs

    def initialize_carry(self, key: Key, input_shape: tuple) -> RL2State:
        """Build a fresh ``RL2State``.

        Args:
            key: PRNG key passed to ``sequence_model.initialize_carry``.
            input_shape: Full input shape; its first entry is used as the
                batch size and its last entry as the feature size.

        Returns:
            An ``RL2State`` with the inner model's initial carry and an
            all-zero int32 step counter of shape ``(batch_size,)``.
        """
        batch_size, *_, features = input_shape
        return RL2State(
            carry=self.sequence_model.initialize_carry(key, (batch_size, features)),
            step=jnp.zeros((batch_size,), dtype=jnp.int32),
        )