import math
from abc import abstractmethod
import jax
import jax.numpy as jnp
from flax import linen as nn
from memorax.utils.axes import (
add_feature_axis,
broadcast_done,
get_input_shape,
init,
last,
tail,
)
from memorax.utils.typing import Array, Carry, Key
from .sequence_model import SequenceModel
[docs]
class MemoroidCellBase(nn.Module):
    """Interface for the cell that a ``Memoroid`` sequence model wraps.

    A concrete cell supplies four pieces: ``__call__`` (turn inputs into
    scan elements), ``binary_operator`` (combine two scan elements),
    ``read`` (turn scanned states into outputs) and ``initialize_carry``.

    The remaining methods are optional hooks for sensitivity propagation.
    Their defaults either return ``None`` (``local_jacobian``,
    ``initialize_sensitivity``) or provide a generic implementation
    (``compute_phantom``, ``inject_phantom``).
    """

    @abstractmethod
    def __call__(self, x: Array, **kwargs) -> Carry:
        """Map the inputs ``x`` to the elements that will be scanned over."""

    @abstractmethod
    def binary_operator(self, a: Carry, b: Carry) -> Carry:
        """Combine two scan elements ``a`` and ``b`` into one.

        ``Memoroid`` passes this to ``jax.lax.associative_scan``, so it is
        expected to be associative.
        """

    @abstractmethod
    def read(self, h: Carry, x: Array, **kwargs) -> Array:
        """Produce the outputs from the scanned states ``h`` and inputs ``x``."""

    @abstractmethod
    def initialize_carry(
        self, key: jax.Array, input_shape: tuple[int, ...]
    ) -> Carry:
        """Create the initial carry for inputs of shape ``input_shape``."""

    def local_jacobian(
        self, carry: Carry, z: Carry, inputs: Array, **kwargs
    ) -> tuple[Array, dict] | None:
        """Return ``(decay, jacobians)`` for sensitivity propagation.

        The base implementation returns ``None``; cells that support
        sensitivity propagation override it.
        """
        return None

    def compute_phantom(self, sensitivity: dict) -> Array:
        """Build the "phantom" term that links the carry to the parameters.

        Each key of ``sensitivity`` is a ``"/"``-separated path into
        ``self.variables["params"]``; each value is the sensitivity array
        for that parameter. For every entry the parameter is looked up,
        ``param - stop_gradient(param)`` is formed (numerically zero, but
        differentiable with respect to ``param``), multiplied by the
        sensitivity, and summed over every axis from index 3 onwards.

        Returns:
            The sum of all per-parameter contributions, or the integer ``0``
            when ``sensitivity`` is empty.
        """
        root = self.variables["params"]

        def contribution(path, trace):
            # Walk the nested parameter collection one path segment at a time.
            leaf = root
            for segment in path.split("/"):
                leaf = leaf[segment]
            # Zero in value; only its gradient w.r.t. the parameter matters.
            delta = leaf - jax.lax.stop_gradient(leaf)
            trailing_axes = tuple(range(3, trace.ndim))
            return jnp.sum(trace * delta, axis=trailing_axes)

        return sum(
            contribution(path, trace) for path, trace in sensitivity.items()
        )

    def inject_phantom(self, carry: Carry, phantom: Array) -> Carry:
        """Replace the first carry entry with a detached copy plus ``phantom``.

        All remaining entries of ``carry`` are passed through unchanged. The
        result is always a tuple.
        """
        first, *others = carry
        detached = jax.lax.stop_gradient(first)
        return (detached + phantom, *others)

    def initialize_sensitivity(self, key: Key, input_shape: tuple) -> dict | None:
        """Create the initial sensitivity dict; the base cell has none."""
        return None
[docs]
class Memoroid(SequenceModel):
    """Sequence model that runs a ``MemoroidCellBase`` with an associative scan.

    The cell turns the inputs into scan elements, the elements are combined
    along axis 1 with ``jax.lax.associative_scan`` (restarting wherever
    ``done`` is set), and the cell reads the outputs from the scanned states.
    """

    # The cell providing the element map, the combine operator and the readout.
    cell: MemoroidCellBase

    def scan_fn(self, z, initial_carry, done):
        """Scan the elements ``z`` along axis 1, honouring episode resets.

        Args:
            z: Pytree of scan elements produced by the cell.
            initial_carry: Pytree with the same structure as ``z``; it is
                prepended to ``z`` along axis 1.
            done: 2-D array of reset flags. A zero column is prepended so
                that the initial carry itself is never treated as a reset.

        Returns:
            ``(h, next_carry)`` where ``h`` is ``tail`` applied to every leaf
            of the scanned states and ``next_carry`` is ``last`` applied to
            every leaf (presumably "all but the prepended step" and "final
            step" respectively - see ``memorax.utils.axes``).
        """
        z = jax.tree.map(
            lambda c, e: jnp.concatenate([c, e], axis=1),
            initial_carry,
            z,
        )
        reset = jnp.concatenate([jnp.zeros((done.shape[0], 1)), done], axis=1)
        reset = add_feature_axis(reset)

        cell = self.cell

        @jax.vmap
        def binary_operator(lhs, rhs):
            lhs_carry, lhs_reset = lhs
            rhs_carry, rhs_reset = rhs
            combined = cell.binary_operator(lhs_carry, rhs_carry)
            # Where the right-hand element is flagged as a reset, drop the
            # accumulated history and keep the right-hand element as is.
            out = jax.tree.map(
                lambda rc, c: jnp.where(broadcast_done(rhs_reset, rc), rc, c),
                rhs_carry,
                combined,
            )
            # A combined span counts as reset if either half does.
            return out, jnp.maximum(lhs_reset, rhs_reset)

        h, _ = jax.lax.associative_scan(binary_operator, (z, reset), axis=1)
        next_carry = jax.tree.map(last, h)
        h = jax.tree.map(tail, h)
        return h, next_carry

    @nn.compact
    def __call__(
        self,
        inputs: Array,
        done: Array,
        initial_carry: Carry | None = None,
        **kwargs,
    ) -> tuple[Carry, Array]:
        """Run the model over ``inputs``.

        Args:
            inputs: Input sequence passed to the cell.
            done: Reset flags forwarded to ``scan_fn``.
            initial_carry: Carry to start from. When ``None`` the cell's
                ``initialize_carry`` is called with ``jax.random.key(0)``.
            **kwargs: Forwarded to the cell's ``__call__`` and ``read``.

        Returns:
            ``(next_carry, y)``: the carry after the scan and the outputs.
        """
        if initial_carry is None:
            input_shape = get_input_shape(inputs)
            initial_carry = self.cell.initialize_carry(jax.random.key(0), input_shape)

        z = self.cell(inputs, **kwargs)
        h, next_carry = self.scan_fn(z, initial_carry, done)
        y = self.cell.read(h, inputs, **kwargs)
        return next_carry, y

    def initialize_carry(self, key: jax.Array, input_shape: tuple[int, ...]) -> Carry:
        """Delegate carry initialisation to the cell."""
        return self.cell.initialize_carry(key, input_shape)

    def _propagate_sensitivities(
        self, decay: Array, jacobians: dict, sensitivity: dict, done: Array
    ) -> dict:
        """Advance every sensitivity through the linear recurrence.

        For each parameter the recurrence ``S_t = a_t * S_{t-1} + J_t`` is
        evaluated with an associative scan along axis 1, where ``a_t`` is
        ``decay`` with the positions flagged by ``done`` set to zero.

        Args:
            decay: Rank-3 array ``(B, T, H)``.
            jacobians: Maps parameter name to an array reshapeable to
                ``(B, T, H * param_size)``, i.e. ``(B, T, H, *param_shape)``.
            sensitivity: Maps the same names to arrays reshapeable to
                ``(B, 1, H * param_size)``.
            done: Reset flags; a feature axis is added before use.

        Returns:
            Dict mapping each name in ``jacobians`` to the final scanned
            state (``last`` of the scan), reshaped to
            ``(B, 1, H, *param_shape)``.
        """
        B, T, H = decay.shape
        done = add_feature_axis(done)
        next_sensitivity = {}

        @jax.vmap
        def binary_operator(a, b):
            # Composition of two affine maps s -> decay * s + state.
            state_i, decay_i = a
            state_j, decay_j = b
            return (decay_j * state_i + state_j, decay_j * decay_i)

        for name in sorted(jacobians):
            J = jacobians[name]
            S = sensitivity[name]
            _, _, _, *param_shape = J.shape
            param_size = math.prod(param_shape)

            # Flatten (H, *param_shape) into one axis so the scan is elementwise.
            J = J.reshape(B, T, H * param_size)
            S = S.reshape(B, 1, H * param_size)

            # Zero decay at flagged steps so earlier sensitivity does not leak through.
            a = jnp.where(done, 0, jnp.repeat(decay, param_size, axis=-1))

            # Prepend the incoming sensitivity as step 0 with unit decay.
            state = jnp.concatenate([S, J], axis=1)
            a = jnp.concatenate([jnp.ones_like(S), a], axis=1)

            state, _ = jax.lax.associative_scan(binary_operator, (state, a), axis=1)
            next_sensitivity[name] = last(state).reshape(B, 1, H, *param_shape)

        return next_sensitivity

    @nn.compact
    def local_jacobian(
        self,
        inputs: Array,
        done: Array,
        carry: Carry,
        sensitivity: dict | None = None,
        **kwargs,
    ) -> tuple[Carry, Array, dict | None]:
        """Run the model and, optionally, update parameter sensitivities.

        Args:
            inputs: Input sequence passed to the cell.
            done: Reset flags.
            carry: Carry to start from.
            sensitivity: Per-parameter sensitivities, or ``None`` to skip
                all sensitivity handling.
            **kwargs: Forwarded to the cell's ``__call__`` and ``read``.

        Returns:
            ``(next_carry, y, next_sensitivity)``. ``next_sensitivity`` is
            ``None`` when ``sensitivity`` is ``None``. It is ``sensitivity``
            unchanged when the cell's ``local_jacobian`` returns ``None`` or
            an empty jacobian dict. Otherwise it is the propagated
            sensitivities.
        """
        z = self.cell(inputs, **kwargs)

        if sensitivity is not None:
            # Tie the incoming carry to the parameters through the phantom term.
            phantom = self.cell.compute_phantom(sensitivity)
            carry = self.cell.inject_phantom(carry, phantom)

        h, next_carry = self.scan_fn(z, carry, done)
        y = self.cell.read(h, inputs, **kwargs)

        if sensitivity is None:
            return next_carry, y, None

        # Carry seen on entry to each step: the initial carry followed by
        # ``init(h)`` (presumably all but the final state), zeroed at resets.
        prev_carry = jax.tree.map(
            lambda initial_carry, hidden_states: jnp.concatenate(
                [initial_carry, init(hidden_states)], axis=1
            ),
            carry,
            h,
        )
        reset = add_feature_axis(done)
        prev_carry = jax.tree.map(
            lambda c: jnp.where(broadcast_done(reset, c), 0, c),
            prev_carry,
        )

        local = self.cell.local_jacobian(prev_carry, z, inputs)
        if local is None:
            # The base cell returns None when it has no local Jacobian;
            # unpacking that would raise TypeError, so pass through instead.
            return next_carry, y, sensitivity

        decay, jacobians = local
        if not jacobians:
            return next_carry, y, sensitivity

        next_sensitivity = self._propagate_sensitivities(
            decay, jacobians, sensitivity, done
        )
        return next_carry, y, next_sensitivity

    def initialize_sensitivity(
        self, key: jax.Array, input_shape: tuple[int, ...]
    ) -> dict | None:
        """Delegate sensitivity initialisation to the cell."""
        return self.cell.initialize_sensitivity(key, input_shape)