from abc import abstractmethod
from typing import Mapping
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.core.frozen_dict import FrozenDict
from flax.core.scope import CollectionFilter, PRNGSequenceFilter
from flax.typing import InOutScanAxis
from memorax.utils.axes import (
add_feature_axis,
broadcast_done,
get_time_axis_and_input_shape,
reset_carry,
)
from memorax.utils.typing import Array, Carry, Key
from .sequence_model import SequenceModel
class RNNCellBase(nn.recurrent.RNNCellBase):
    """Recurrent cell interface extended with sensitivity/phantom hooks.

    On top of the Flax recurrent-cell base class, a subclass supplies
    ``local_jacobian``, ``inject_phantom`` and ``initialize_sensitivity``.
    ``compute_phantom`` is implemented here and shared by all subclasses.
    """

    @abstractmethod
    def local_jacobian(
        self, carry: Carry, inputs: Array, sensitivity: dict[str, Array], **kwargs
    ) -> tuple[Carry, Array, dict[str, Array]]:
        """Advance the cell one step while also updating the sensitivities.

        Args:
            carry: Recurrent state entering the step.
            inputs: Input for the step.
            sensitivity: Mapping from ``"/"``-separated parameter path to the
                sensitivity array tracked for that parameter.
            **kwargs: Extra, implementation-specific options.

        Returns:
            A ``(next_carry, output, next_sensitivity)`` tuple.
        """

    def compute_phantom(self, sensitivity: dict[str, Array]) -> Array:
        """Combine all sensitivities into one gradient-carrying phantom term.

        For each entry, the parameter addressed by the ``"/"``-separated key is
        looked up in this module's ``"params"`` collection, and
        ``param - stop_gradient(param)`` is multiplied by the sensitivity. The
        product is summed over every axis from index 3 onward and the
        per-parameter results are added together.

        Args:
            sensitivity: Mapping from parameter path to sensitivity array.
                Presumably each array has three leading axes followed by the
                parameter's own axes -- TODO confirm against the cells that
                produce it.

        Returns:
            The accumulated phantom term. For an empty mapping this is the
            plain integer ``0``.
        """
        param_tree = self.variables["params"]

        def lookup(path: str):
            # Walk the nested parameter mapping one path component at a time.
            node = param_tree
            for part in path.split("/"):
                node = node[part]
            return node

        total = 0
        for path, sens in sensitivity.items():
            leaf = lookup(path)
            # The forward value cancels out (for finite parameters), but
            # gradients still flow through the first term back to the parameter.
            grad_only = leaf - jax.lax.stop_gradient(leaf)
            # Keep the first three axes of the sensitivity; reduce the rest.
            reduced_axes = tuple(range(3, sens.ndim))
            total = total + jnp.sum(sens * grad_only, axis=reduced_axes)
        return total

    @abstractmethod
    def inject_phantom(self, carry: Carry, phantom: Array) -> Carry:
        """Return ``carry`` with the ``phantom`` term folded into it.

        Args:
            carry: Recurrent state to augment.
            phantom: Term produced by ``compute_phantom``.

        Returns:
            The augmented carry.
        """

    @abstractmethod
    def initialize_sensitivity(
        self, key: jax.Array, input_shape: tuple[int, ...]
    ) -> dict[str, Array] | None:
        """Create the initial sensitivities for inputs of ``input_shape``.

        Args:
            key: PRNG key available to the initialisation.
            input_shape: Shape of a single (per-step) input.

        Returns:
            A mapping from parameter path to sensitivity array, or ``None``.
        """
[docs]
class RNN(SequenceModel):
    """Sequence model that scans a recurrent cell along the time axis.

    At every step the carry is first passed through ``reset_carry`` together
    with that step's ``done`` flag and a freshly initialised carry, and only
    then handed to the cell.
    """

    # Recurrent cell applied at each step. ``local_jacobian`` additionally
    # needs it to provide ``compute_phantom``, ``inject_phantom`` and
    # ``local_jacobian``.
    cell: nn.RNNCellBase
    # The fields below are forwarded unchanged to ``nn.transforms.scan``.
    unroll: int = 1
    variable_axes: Mapping[CollectionFilter, InOutScanAxis] = FrozenDict()
    variable_broadcast: CollectionFilter = "params"
    variable_carry: CollectionFilter = False
    split_rngs: Mapping[PRNGSequenceFilter, bool] = FrozenDict({"params": False})

    def __call__(
        self,
        inputs: Array,
        done: Array,
        initial_carry: Carry | None = None,
        **kwargs,
    ) -> tuple[Carry, Array]:
        """Run the cell over the whole input sequence.

        Args:
            inputs: Input sequence. The time axis and per-step input shape are
                derived from it by ``get_time_axis_and_input_shape``.
            done: Per-step flags, scanned along the same axis as ``inputs``
                and given to ``reset_carry`` before each cell call.
            initial_carry: Carry to start from. When ``None`` the cell's own
                initial carry is used.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            ``(final_carry, outputs)`` with the outputs stacked along the
            time axis.
        """
        time_axis, input_shape = get_time_axis_and_input_shape(inputs)

        # The reset value is the same for every step, so build it once here
        # instead of inside the scan body.
        reset_value = self.cell.initialize_carry(jax.random.key(0), input_shape)
        if initial_carry is None:
            initial_carry = reset_value

        def scan_fn(cell, carry, x, done_t):
            carry = reset_carry(done_t, carry, reset_value)
            carry, y = cell(carry, x)
            return carry, y

        scan = nn.transforms.scan(
            scan_fn,
            in_axes=time_axis,
            out_axes=time_axis,
            unroll=self.unroll,
            variable_axes=self.variable_axes,
            variable_broadcast=self.variable_broadcast,
            variable_carry=self.variable_carry,
            split_rngs=self.split_rngs,
        )
        final_carry, outputs = scan(self.cell, initial_carry, inputs, done)
        return final_carry, outputs

    @nn.nowrap
    def initialize_carry(self, key: jax.Array, input_shape: tuple[int, ...]) -> Carry:
        """Return the wrapped cell's initial carry for ``input_shape``."""
        return self.cell.initialize_carry(key, input_shape)

    def local_jacobian(
        self,
        inputs: Array,
        done: Array,
        carry: Carry,
        sensitivity: dict[str, Array] | None = None,
        **kwargs,
    ) -> tuple[Carry, Array, dict[str, Array] | None]:
        """Run the sequence while propagating parameter sensitivities.

        Without ``sensitivity`` this is an ordinary forward pass through
        ``__call__``. With it, every step:

        1. builds the phantom term from the current sensitivities and injects
           it into the carry,
        2. passes the carry through ``reset_carry`` with the step's ``done``
           flag,
        3. zeroes the sensitivities where the broadcast ``done`` mask is set,
        4. calls the cell's ``local_jacobian``.

        Args:
            inputs: Input sequence.
            done: Per-step flags, scanned along the same axis as ``inputs``.
            carry: Carry to start from. ``None`` selects the cell's initial
                carry.
            sensitivity: Mapping from parameter path to sensitivity array, or
                ``None`` to skip sensitivity tracking.
            **kwargs: Forwarded to ``__call__`` when ``sensitivity`` is
                ``None``; not used otherwise.

        Returns:
            ``(final_carry, outputs, final_sensitivity)``. The last element is
            ``None`` when no sensitivities were supplied.
        """
        if sensitivity is None:
            next_carry, y = self(inputs, done, carry, **kwargs)
            return next_carry, y, None

        time_axis, input_shape = get_time_axis_and_input_shape(inputs)
        initial_carry = self.cell.initialize_carry(jax.random.key(0), input_shape)
        if carry is None:
            # Same fallback as ``__call__``; otherwise ``None`` would be fed
            # straight into the scan.
            carry = initial_carry

        def scan_fn(cell, state, x, done_t):
            cell_carry, step_sensitivity = state
            # The phantom is injected before the reset is applied.
            phantom = cell.compute_phantom(step_sensitivity)
            cell_carry = cell.inject_phantom(cell_carry, phantom)
            cell_carry = reset_carry(done_t, cell_carry, initial_carry)
            step_sensitivity = jax.tree.map(
                lambda s: jnp.where(broadcast_done(add_feature_axis(done_t), s), 0, s),
                step_sensitivity,
            )
            next_carry, y, next_sensitivity = cell.local_jacobian(
                cell_carry, x, step_sensitivity
            )
            return (next_carry, next_sensitivity), y

        scan = nn.transforms.scan(
            scan_fn,
            in_axes=time_axis,
            out_axes=time_axis,
            unroll=self.unroll,
            variable_axes=self.variable_axes,
            variable_broadcast=self.variable_broadcast,
            variable_carry=self.variable_carry,
            split_rngs=self.split_rngs,
        )
        (final_carry, final_sensitivity), outputs = scan(
            self.cell, (carry, sensitivity), inputs, done
        )
        return final_carry, outputs, final_sensitivity

    def initialize_sensitivity(
        self, key: Key, input_shape: tuple[int, ...]
    ) -> dict[str, Array] | None:
        """Return the wrapped cell's initial sensitivities for ``input_shape``."""
        return self.cell.initialize_sensitivity(key, input_shape)