Source code for memorax.networks.sequence_models.rnn

from abc import abstractmethod
from typing import Mapping

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.core.frozen_dict import FrozenDict
from flax.core.scope import CollectionFilter, PRNGSequenceFilter
from flax.typing import InOutScanAxis

from memorax.utils.axes import (
    add_feature_axis,
    broadcast_done,
    get_time_axis_and_input_shape,
    reset_carry,
)
from memorax.utils.typing import Array, Carry, Key

from .sequence_model import SequenceModel


class RNNCellBase(nn.recurrent.RNNCellBase):
    """Recurrent-cell interface extended with sensitivity/phantom hooks.

    Subclasses supply ``local_jacobian``, ``inject_phantom`` and
    ``initialize_sensitivity``; ``compute_phantom`` is implemented here.
    """

    @abstractmethod
    def local_jacobian(
        self, carry: Carry, inputs: Array, sensitivity: dict[str, Array], **kwargs
    ) -> tuple[Carry, Array, dict[str, Array]]:
        """Run one step of the cell and update the sensitivity mapping.

        Returns:
            A ``(next_carry, output, next_sensitivity)`` triple, per the
            annotated return type. The exact update rule is defined by the
            subclass.
        """

    def compute_phantom(self, sensitivity: dict[str, Array]) -> Array:
        """Accumulate a phantom term from the given sensitivities.

        Args:
            sensitivity: Mapping from a ``"/"``-separated parameter path
                (resolved inside ``self.variables["params"]``) to an array
                that is multiplied with that parameter's difference term.

        Returns:
            The sum over all entries of ``sum(S * (p - stop_gradient(p)))``,
            reduced over every axis of ``S`` from index 3 onward. Returns the
            integer ``0`` when ``sensitivity`` is empty.
        """
        param_tree = self.variables["params"]

        def entry_term(path: str, trace: Array) -> Array:
            # Walk the nested parameter collection along the "/"-separated path.
            leaf = param_tree
            for part in path.split("/"):
                leaf = leaf[part]
            # Numerically zero, but still differentiable w.r.t. the parameter.
            delta = leaf - jax.lax.stop_gradient(leaf)
            # NOTE(review): assumes the first three axes of `trace` are the
            # ones to keep -- confirm against the cell implementations.
            trailing_axes = tuple(range(3, trace.ndim))
            return jnp.sum(trace * delta, axis=trailing_axes)

        # `sum` starts from 0 and adds left to right, matching a manual
        # accumulation loop over the mapping's iteration order.
        return sum(entry_term(path, trace) for path, trace in sensitivity.items())

    @abstractmethod
    def inject_phantom(self, carry: Carry, phantom: Array) -> Carry:
        """Return a carry that incorporates ``phantom`` (subclass-defined)."""

    @abstractmethod
    def initialize_sensitivity(
        self, key: jax.Array, input_shape: tuple[int, ...]
    ) -> dict[str, Array] | None:
        """Build the initial sensitivity mapping, or ``None`` (subclass-defined)."""


class RNN(SequenceModel):
    """Sequence model that scans a recurrent cell over the time axis."""

    # Recurrent cell invoked once per time step inside the scan.
    cell: nn.RNNCellBase
    # The remaining fields are forwarded unchanged to `nn.transforms.scan`.
    unroll: int = 1
    variable_axes: Mapping[CollectionFilter, InOutScanAxis] = FrozenDict()
    variable_broadcast: CollectionFilter = "params"
    variable_carry: CollectionFilter = False
    split_rngs: Mapping[PRNGSequenceFilter, bool] = FrozenDict({"params": False})

    def __call__(
        self,
        inputs: Array,
        done: Array,
        initial_carry: Carry | None = None,
        **kwargs,
    ) -> tuple[Carry, Array]:
        """Run the cell over ``inputs`` along the detected time axis.

        Args:
            inputs: Input sequence; its time axis and per-step input shape
                come from ``get_time_axis_and_input_shape``.
            done: Per-step flags passed to ``reset_carry`` before each cell
                call (presumably marking episode boundaries -- confirm).
            initial_carry: Starting carry. When ``None``, one is created with
                ``self.cell.initialize_carry(jax.random.key(0), input_shape)``.
            **kwargs: Accepted but not used by this method.

        Returns:
            ``(final_carry, outputs)`` as produced by the scan.
        """
        time_axis, input_shape = get_time_axis_and_input_shape(inputs)
        start_carry: Carry = (
            self.cell.initialize_carry(jax.random.key(0), input_shape)
            if initial_carry is None
            else initial_carry
        )

        def step(cell, state, x_t, done_t):
            # The reset value is a freshly initialised carry, not `initial_carry`.
            fresh = self.cell.initialize_carry(jax.random.key(0), input_shape)
            state = reset_carry(done_t, state, fresh)
            new_state, y_t = cell(state, x_t)
            return new_state, y_t

        scan_options = {
            "in_axes": time_axis,
            "out_axes": time_axis,
            "unroll": self.unroll,
            "variable_axes": self.variable_axes,
            "variable_broadcast": self.variable_broadcast,
            "variable_carry": self.variable_carry,
            "split_rngs": self.split_rngs,
        }
        run = nn.transforms.scan(step, **scan_options)
        final_carry, outputs = run(self.cell, start_carry, inputs, done)
        return final_carry, outputs

    @nn.nowrap
    def initialize_carry(self, key: jax.Array, input_shape: tuple[int, ...]) -> Carry:
        """Delegate carry initialisation to the wrapped cell."""
        return self.cell.initialize_carry(key, input_shape)

    def local_jacobian(
        self,
        inputs: Array,
        done: Array,
        carry: Carry,
        sensitivity: dict[str, Array] | None = None,
        **kwargs,
    ) -> tuple[Carry, Array, dict[str, Array] | None]:
        """Scan the cell's ``local_jacobian`` step, threading sensitivities.

        Args:
            inputs: Input sequence, scanned along its detected time axis.
            done: Per-step flags used to reset the carry and to zero the
                sensitivities before each step.
            carry: Carry to start the scan from.
            sensitivity: Sensitivity mapping to thread through the scan. When
                ``None`` this falls back to a plain forward call.
            **kwargs: Forwarded to ``self(...)`` only in the
                ``sensitivity is None`` fallback; otherwise unused.

        Returns:
            ``(next_carry, outputs, next_sensitivity)``; the last element is
            ``None`` in the fallback path.
        """
        if sensitivity is None:
            final_carry, outputs = self(inputs, done, carry, **kwargs)
            return final_carry, outputs, None

        time_axis, input_shape = get_time_axis_and_input_shape(inputs)
        fresh_carry = self.cell.initialize_carry(jax.random.key(0), input_shape)

        def step(cell, state, x_t, done_t):
            prev_carry, prev_sens = state

            # The phantom is computed from the incoming (un-masked)
            # sensitivities and injected before the carry is reset.
            prev_carry = cell.inject_phantom(
                prev_carry, cell.compute_phantom(prev_sens)
            )
            prev_carry = reset_carry(done_t, prev_carry, fresh_carry)

            def zero_where_done(leaf):
                mask = broadcast_done(add_feature_axis(done_t), leaf)
                return jnp.where(mask, 0, leaf)

            masked_sens = jax.tree.map(zero_where_done, prev_sens)
            new_carry, y_t, new_sens = cell.local_jacobian(
                prev_carry, x_t, masked_sens
            )
            return (new_carry, new_sens), y_t

        scan_options = {
            "in_axes": time_axis,
            "out_axes": time_axis,
            "unroll": self.unroll,
            "variable_axes": self.variable_axes,
            "variable_broadcast": self.variable_broadcast,
            "variable_carry": self.variable_carry,
            "split_rngs": self.split_rngs,
        }
        run = nn.transforms.scan(step, **scan_options)
        (final_carry, final_sens), outputs = run(
            self.cell, (carry, sensitivity), inputs, done
        )
        return final_carry, outputs, final_sens

    def initialize_sensitivity(
        self, key: Key, input_shape: tuple
    ) -> dict[str, Array] | None:
        """Delegate sensitivity initialisation to the wrapped cell."""
        return self.cell.initialize_sensitivity(key, input_shape)