Source code for memorax.networks.sequence_models.rtrl
import jax
from flax import linen as nn
from memorax.utils.typing import Array, Carry
from .sequence_model import SequenceModel
from memorax.utils.axes import get_input_shape
[docs]
class RTRL(SequenceModel):
    """Run an inner sequence model through its RTRL interface.

    The carry of this module is a pair ``(carry, sensitivity)``.

    - ``carry`` is the inner model's own carry.
    - ``sensitivity`` is the extra state the inner model creates with
      ``initialize_sensitivity`` and updates with ``local_jacobian``.

    Every call is delegated to ``self.sequence_model.local_jacobian``.
    """

    # Inner model. It must provide ``initialize_carry``,
    # ``initialize_sensitivity`` and ``local_jacobian``.
    sequence_model: SequenceModel

    @nn.compact
    def __call__(
        self,
        inputs: Array,
        done: Array,
        initial_carry: Carry | None = None,
        **kwargs,
    ) -> tuple[Carry, Array]:
        """Advance the inner model together with its sensitivity.

        Args:
            inputs: Input array forwarded to the inner model's
                ``local_jacobian``.
            done: Forwarded unchanged to ``local_jacobian``. Presumably these
                are reset / episode-boundary flags; confirm against the inner
                model.
            initial_carry: A ``(carry, sensitivity)`` pair. When ``None``, a
                fresh pair is built with ``initialize_carry``. That pair uses
                the fixed key ``jax.random.key(0)``, so the default starting
                state is the same on every call.
            **kwargs: Extra keyword arguments forwarded to ``local_jacobian``.

        Returns:
            ``((next_carry, next_sensitivity), y)``: the updated carry pair and
            the inner model's output.

        Raises:
            AssertionError: If the sensitivity is ``None``, i.e. the inner
                model does not support RTRL.
        """
        if initial_carry is None:
            input_shape = get_input_shape(inputs)
            initial_carry = self.initialize_carry(jax.random.key(0), input_shape)
        carry, sensitivity = initial_carry

        # Use an explicit raise instead of ``assert``. Asserts are stripped
        # under ``python -O``, which would let ``sensitivity=None`` reach
        # ``local_jacobian``. AssertionError is kept so callers that already
        # expect it are unaffected.
        if sensitivity is None:
            raise AssertionError(
                f"{type(self.sequence_model).__name__} does not support RTRL. "
                "Ensure the inner model implements local_jacobian and initialize_sensitivity."
            )

        next_carry, y, next_sensitivity = self.sequence_model.local_jacobian(
            inputs, done, carry, sensitivity=sensitivity, **kwargs
        )
        return (next_carry, next_sensitivity), y

    def initialize_carry(self, key: jax.Array, input_shape: tuple[int, ...]) -> Carry:
        """Build the ``(carry, sensitivity)`` pair from the inner model.

        Args:
            key: PRNG key passed to both of the inner model's initializers.
            input_shape: Shape passed to both of the inner model's
                initializers.

        Returns:
            ``(carry, sensitivity)``, where:

            - ``carry`` comes from ``sequence_model.initialize_carry``.
            - ``sensitivity`` comes from
              ``sequence_model.initialize_sensitivity``. It may be ``None``;
              ``__call__`` rejects that case.
        """
        # NOTE(review): the same ``key`` feeds both initializers. If either
        # one actually draws random numbers, consider splitting the key.
        # Confirm against the inner models before changing this, because it
        # alters the initial state.
        carry = self.sequence_model.initialize_carry(key, input_shape)
        sensitivity = self.sequence_model.initialize_sensitivity(key, input_shape)
        return (carry, sensitivity)