Source code for memorax.networks.sequence_models.wrappers

from typing import Optional

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax import struct

from .sequence_model import SequenceModel


class SequenceModelWrapper(SequenceModel, nn.Module):
    """Expose an arbitrary module through the ``SequenceModel`` call signature.

    The wrapped ``network`` is applied to the inputs directly. The carry is
    handed back exactly as it was received, and ``mask`` is not read.
    """

    # Module that is invoked on the inputs in ``__call__``.
    network: nn.Module

    def __call__(self, inputs, mask, initial_carry=None, **kwargs):
        """Apply ``network`` to ``inputs`` and pass the carry through untouched.

        Args:
            inputs: Handed to ``network`` as its only positional argument.
            mask: Not read by this wrapper.
            initial_carry: Returned as-is (``None`` when omitted).
            **kwargs: Forwarded verbatim to ``network``.

        Returns:
            A ``(carry, outputs)`` tuple in which ``carry`` is ``initial_carry``
            and ``outputs`` is whatever ``network`` returned.
        """
        outputs = self.network(inputs, **kwargs)
        return initial_carry, outputs

    def initialize_carry(self, key, input_shape):
        """Create a placeholder carry of zeros with shape ``(batch_size, 1)``.

        Args:
            key: Not used by this wrapper.
            input_shape: A two-element shape; only its first entry (the batch
                size) is used. Any other length fails at the unpacking below.

        Returns:
            A zero-filled array of shape ``(batch_size, 1)``.
        """
        (batch_size, _features) = input_shape
        placeholder_shape = (batch_size, 1)
        return jnp.zeros(placeholder_shape)
@struct.dataclass
class MetaMaskState:
    """Carry used by ``MetaMaskWrapper``: the inner carry plus a step counter."""

    # Carry produced by / passed to the wrapped sequence model.
    carry: jnp.ndarray
    # Number of time steps consumed so far, one entry per batch element
    # (created as an int32 array of shape ``(batch_size,)``).
    step: jnp.ndarray
class MetaMaskWrapper(SequenceModel, nn.Module):
    """Drive a sequence model with a mask derived from a running step counter.

    The wrapper keeps a per-batch-element step count in its carry
    (``MetaMaskState.step``). On every call it computes the global step index
    of each time step in the input and builds a fresh mask from it; the mask
    supplied by the caller is not used.
    """

    # Wrapped model; must provide ``__call__(inputs, mask, carry, **kwargs)``
    # and ``initialize_carry(key, input_shape)``.
    sequence_model: nn.Module
    # Period, in time steps, at which the generated mask is False.
    steps_per_trial: int

    def __call__(self, inputs, mask, initial_carry=None, **kwargs):
        """Run the wrapped model with a step-derived mask.

        Args:
            inputs: Array whose first two axes are read as
                ``(batch, sequence_length)``; any trailing axes are untouched.
            mask: Ignored. It is replaced by the mask computed below.
            initial_carry: A ``MetaMaskState``. When ``None``, a fresh one is
                created via ``initialize_carry`` with ``jax.random.key(0)``.
            **kwargs: Forwarded to the wrapped ``sequence_model``.

        Returns:
            A ``(carry, outputs)`` tuple where ``carry`` is a ``MetaMaskState``
            holding the wrapped model's new carry and the step counter advanced
            by ``sequence_length``, and ``outputs`` is the wrapped model's
            output.
        """
        _, sequence_length, *_ = inputs.shape

        if initial_carry is None:
            initial_carry = self.initialize_carry(jax.random.key(0), inputs.shape)

        # Global step index of every (batch element, time step) pair.
        time_indices = jnp.arange(sequence_length)
        steps = initial_carry.step[:, None] + time_indices[None, :]

        # Overwrites the caller's mask: False exactly where the global step is
        # a multiple of ``steps_per_trial``, True everywhere else.
        # NOTE(review): how the wrapped model interprets True/False is not
        # visible here -- confirm the polarity against its implementation.
        mask = steps % self.steps_per_trial != 0

        # Forward the keyword arguments so they are not silently dropped.
        carry, outputs = self.sequence_model(
            inputs, mask, initial_carry.carry, **kwargs
        )
        carry = MetaMaskState(carry=carry, step=initial_carry.step + sequence_length)
        return carry, outputs

    def initialize_carry(self, key, input_shape):
        """Create the initial ``MetaMaskState``.

        Args:
            key: PRNG key passed through to the wrapped model's
                ``initialize_carry``.
            input_shape: Shape whose first entry is the batch size and whose
                last entry is the feature size; entries in between are ignored.

        Returns:
            A ``MetaMaskState`` with the wrapped model's initial carry (built
            for shape ``(batch_size, features)``) and an int32 zero step
            counter of shape ``(batch_size,)``.
        """
        batch_size, *_, features = input_shape
        return MetaMaskState(
            carry=self.sequence_model.initialize_carry(key, (batch_size, features)),
            step=jnp.zeros((batch_size,), dtype=jnp.int32),
        )