Source code for memorax.networks.sequence_models.memoroid

import math
from abc import abstractmethod
import jax
import jax.numpy as jnp
from flax import linen as nn

from memorax.utils.axes import (
    add_feature_axis,
    broadcast_done,
    get_input_shape,
    init,
    last,
    tail,
)
from memorax.utils.typing import Array, Carry, Key

from .sequence_model import SequenceModel


class MemoroidCellBase(nn.Module):
    """Interface a cell must implement to be driven by ``Memoroid``.

    ``Memoroid`` calls the cell to turn inputs into scan elements, folds those
    elements together with :meth:`binary_operator` inside an associative scan,
    and finally calls :meth:`read` on the scanned states. The remaining
    methods are optional hooks for sensitivity propagation and default to
    "not supported" (``None``).
    """

    @abstractmethod
    def __call__(self, x: Array, **kwargs) -> Carry:
        """Map the inputs ``x`` to the elements that will be scanned."""

    @abstractmethod
    def binary_operator(self, a: Carry, b: Carry) -> Carry:
        """Combine two scan elements ``a`` and ``b`` into one.

        NOTE(review): ``Memoroid`` passes this to
        ``jax.lax.associative_scan``, so it is presumably required to be
        associative -- confirm in concrete cells.
        """

    @abstractmethod
    def read(self, h: Carry, x: Array, **kwargs) -> Array:
        """Produce the output from the scanned states ``h`` and the inputs ``x``."""

    @abstractmethod
    def initialize_carry(
        self, key: jax.Array, input_shape: tuple[int, ...]
    ) -> Carry:
        """Create the initial carry for inputs of shape ``input_shape``."""

    def local_jacobian(
        self, carry: Carry, z: Carry, inputs: Array, **kwargs
    ) -> tuple[Array, dict] | None:
        """Return ``(decay, jacobians)`` for the step, or ``None``.

        The base implementation always returns ``None``; cells that support
        sensitivity propagation override it.
        """
        return None

    def compute_phantom(self, sensitivity: dict) -> Array:
        """Build a zero-valued term that routes gradients to the parameters.

        Each key of ``sensitivity`` is a ``"/"``-separated path into
        ``self.variables["params"]``. For every entry the parameter is
        paired with ``param - stop_gradient(param)``, which evaluates to zero
        but still differentiates to the identity with respect to ``param``.
        That difference is weighted by the stored sensitivity and summed over
        every axis from index 3 onward.

        Args:
            sensitivity: Mapping from parameter path to sensitivity array.

        Returns:
            The accumulated phantom term, or the Python integer ``0`` when
            ``sensitivity`` is empty.
        """
        param_tree = self.variables["params"]
        total = 0
        for path, weight in sensitivity.items():
            # Walk the nested parameter collection one path component at a time.
            leaf = param_tree
            for part in path.split("/"):
                leaf = leaf[part]
            trailing_axes = tuple(range(3, weight.ndim))
            # Value is exactly zero; only its gradient w.r.t. ``leaf`` matters.
            zero_valued = leaf - jax.lax.stop_gradient(leaf)
            total = total + jnp.sum(weight * zero_valued, axis=trailing_axes)
        return total

    def inject_phantom(self, carry: Carry, phantom: Array) -> Carry:
        """Swap the first carry entry for a detached copy plus ``phantom``.

        Gradients no longer flow through the first entry of ``carry`` itself;
        they flow through ``phantom`` instead. All remaining entries are
        passed through untouched.
        """
        head, *others = carry
        detached = jax.lax.stop_gradient(head)
        return (detached + phantom, *others)

    def initialize_sensitivity(self, key: Key, input_shape: tuple) -> dict | None:
        """Return the initial sensitivity dict; the base cell has none."""
        return None
class Memoroid(SequenceModel):
    """Sequence model that runs a ``MemoroidCellBase`` with an associative scan.

    The cell turns the inputs into scan elements, the elements are combined
    along axis 1 with the cell's binary operator (restarting wherever ``done``
    is set), and the cell's ``read`` turns the scanned states into outputs.
    """

    # Cell providing the element mapping, the combine operator and the readout.
    cell: MemoroidCellBase

    def scan_fn(self, z, initial_carry, done):
        """Scan the elements ``z`` along axis 1, starting from ``initial_carry``.

        Args:
            z: Pytree of scan elements produced by the cell.
            initial_carry: Pytree with the same structure as ``z``; it is
                prepended to ``z`` along axis 1.
            done: Reset flags; ``done.shape[0]`` is used as the batch size.
                NOTE(review): assumed to be (batch, time) -- confirm.

        Returns:
            ``(h, next_carry)``: ``tail`` of the scanned states and ``last`` of
            the scanned states. NOTE(review): ``tail``/``last`` are project
            helpers; presumably they drop the prepended initial step and
            select the final step respectively -- confirm.
        """
        z = jax.tree.map(
            lambda c, e: jnp.concatenate([c, e], axis=1),
            initial_carry,
            z,
        )
        # The prepended initial carry gets a zero flag so it is never treated
        # as a reset point.
        reset = jnp.concatenate([jnp.zeros((done.shape[0], 1)), done], axis=1)
        reset = add_feature_axis(reset)

        cell = self.cell

        @jax.vmap
        def binary_operator(lhs, rhs):
            lhs_carry, lhs_reset = lhs
            rhs_carry, rhs_reset = rhs
            combined = cell.binary_operator(lhs_carry, rhs_carry)
            # Where the right-hand element is flagged, keep it as-is so that
            # everything to its left is discarded.
            out = jax.tree.map(
                lambda rc, c: jnp.where(broadcast_done(rhs_reset, rc), rc, c),
                rhs_carry,
                combined,
            )
            return out, jnp.maximum(lhs_reset, rhs_reset)

        h, _ = jax.lax.associative_scan(binary_operator, (z, reset), axis=1)
        next_carry = jax.tree.map(last, h)
        h = jax.tree.map(tail, h)
        return h, next_carry

    @nn.compact
    def __call__(
        self,
        inputs: Array,
        done: Array,
        initial_carry: Carry | None = None,
        **kwargs,
    ) -> tuple[Carry, Array]:
        """Run the cell over ``inputs`` and return ``(next_carry, outputs)``.

        Args:
            inputs: Input sequence, forwarded to the cell and to its ``read``.
            done: Reset flags forwarded to :meth:`scan_fn`.
            initial_carry: Carry to start from. When ``None`` a fresh carry is
                created from the cell with the fixed key ``jax.random.key(0)``.
            **kwargs: Forwarded to the cell call and to ``cell.read``.
        """
        if initial_carry is None:
            input_shape = get_input_shape(inputs)
            initial_carry = self.cell.initialize_carry(jax.random.key(0), input_shape)

        z = self.cell(inputs, **kwargs)
        h, next_carry = self.scan_fn(z, initial_carry, done)
        y = self.cell.read(h, inputs, **kwargs)
        return next_carry, y

    def initialize_carry(self, key: jax.Array, input_shape: tuple[int, ...]) -> Carry:
        """Delegate carry creation to the cell."""
        return self.cell.initialize_carry(key, input_shape)

    def _propagate_sensitivities(
        self, decay: Array, jacobians: dict, sensitivity: dict, done: Array
    ) -> dict:
        """Advance every sensitivity through the sequence with a linear scan.

        For each parameter name the recurrence ``S_t = a_t * S_{t-1} + J_t`` is
        evaluated along axis 1, where ``a_t`` is ``decay`` with flagged
        (``done``) steps set to zero, so the incoming sensitivity is dropped at
        those steps.

        Args:
            decay: Array unpacked as ``(B, T, H)``.
            jacobians: Name -> array whose first three axes are reshaped as
                ``(B, T, H)``; the remaining axes are the parameter shape.
            sensitivity: Name -> array reshaped as ``(B, 1, H * param_size)``;
                must contain every key of ``jacobians``.
            done: Reset flags for the ``T`` steps.

        Returns:
            Name -> sensitivity after the last step, shaped
            ``(B, 1, H, *param_shape)``.
        """
        B, T, H = decay.shape
        done = add_feature_axis(done)

        @jax.vmap
        def binary_operator(a, b):
            # Composition of two affine maps ``s -> decay * s + state``.
            state_i, decay_i = a
            state_j, decay_j = b
            return (decay_j * state_i + state_j, decay_j * decay_i)

        next_sensitivity = {}
        for name in sorted(jacobians):
            J = jacobians[name]
            S = sensitivity[name]
            param_shape = J.shape[3:]
            param_size = math.prod(param_shape)

            # Flatten (H, *param_shape) so the scan runs on a single feature axis.
            J = J.reshape(B, T, H * param_size)
            S = S.reshape(B, 1, H * param_size)

            # Repeat each hidden unit's decay once per parameter entry to line
            # up with the flattened layout, and zero it at flagged steps.
            a = jnp.where(done, 0, jnp.repeat(decay, param_size, axis=-1))

            # The incoming sensitivity is the first scan element; its own
            # decay factor is never applied to anything, so ones are fine.
            state = jnp.concatenate([S, J], axis=1)
            a = jnp.concatenate([jnp.ones_like(S), a], axis=1)

            state, _ = jax.lax.associative_scan(binary_operator, (state, a), axis=1)
            next_sensitivity[name] = last(state).reshape(B, 1, H, *param_shape)
        return next_sensitivity

    @nn.compact
    def local_jacobian(
        self,
        inputs: Array,
        done: Array,
        carry: Carry,
        sensitivity: dict | None = None,
        **kwargs,
    ) -> tuple[Carry, Array, dict | None]:
        """Forward pass that also carries parameter sensitivities forward.

        Args:
            inputs: Input sequence.
            done: Reset flags.
            carry: Carry to start from.
            sensitivity: Sensitivities from the previous call, or ``None`` to
                run a plain forward pass.
            **kwargs: Forwarded to the cell call and to ``cell.read``.

        Returns:
            ``(next_carry, outputs, next_sensitivity)``. ``next_sensitivity``
            is ``None`` when ``sensitivity`` is ``None``. It is the unchanged
            input ``sensitivity`` when the cell reports no local Jacobian
            (returns ``None``) or an empty Jacobian dict.
        """
        z = self.cell(inputs, **kwargs)

        if sensitivity is not None:
            # Route gradients of the incoming carry to the parameters.
            phantom = self.cell.compute_phantom(sensitivity)
            carry = self.cell.inject_phantom(carry, phantom)

        h, next_carry = self.scan_fn(z, carry, done)
        y = self.cell.read(h, inputs, **kwargs)

        if sensitivity is None:
            return next_carry, y, None

        # Carry seen by each step: the incoming carry followed by ``init(h)``.
        # NOTE(review): ``init`` is a project helper, presumably "all but the
        # last step" -- confirm.
        prev_carry = jax.tree.map(
            lambda initial_carry, hidden_states: jnp.concatenate(
                [initial_carry, init(hidden_states)], axis=1
            ),
            carry,
            h,
        )
        # Flagged steps start from a zeroed previous carry.
        reset = add_feature_axis(done)
        prev_carry = jax.tree.map(
            lambda c: jnp.where(broadcast_done(reset, c), 0, c),
            prev_carry,
        )

        local = self.cell.local_jacobian(prev_carry, z, inputs)
        if local is None:
            # The cell's default ``local_jacobian`` returns ``None``; treat
            # that like "nothing to propagate" instead of failing to unpack.
            return next_carry, y, sensitivity

        decay, jacobians = local
        if jacobians:
            next_sensitivity = self._propagate_sensitivities(
                decay, jacobians, sensitivity, done
            )
        else:
            next_sensitivity = sensitivity
        return next_carry, y, next_sensitivity

    def initialize_sensitivity(
        self, key: jax.Array, input_shape: tuple[int, ...]
    ) -> dict | None:
        """Delegate sensitivity creation to the cell."""
        return self.cell.initialize_sensitivity(key, input_shape)