Source code for memorax.networks.sequence_models.s5

from functools import partial
from typing import Tuple

import jax
import jax.numpy as jnp
from flax import struct
from flax.typing import Dtype
from jax.nn.initializers import lecun_normal, normal

from memorax.utils.typing import Array, Carry

from memorax.networks.hippo import discretize_bilinear, discretize_zoh, make_dplr_hippo
from memorax.networks.initializers import init_cv, init_v_inv_b, log_step, truncated_standard_normal

from .memoroid import MemoroidCellBase


def _c_init_complex_normal(key, shape):
    return normal(stddev=0.5**0.5)(key, shape)


def _c_init_lecun(key, shape, v):
    return init_cv(lecun_normal(), key, shape, v)


def _c_init_truncated(key, shape, v):
    return init_cv(truncated_standard_normal, key, shape, v)


@struct.dataclass
class S5Config:
    features: int
    hidden_dim: int
    c_init: str = "truncated_standard_normal"
    discretization: str = "zoh"
    dt_min: float = 0.001
    dt_max: float = 0.1
    clip_eigens: bool = False
    step_rescale: float = 1.0
    dtype: Dtype | None = None
    param_dtype: Dtype = jnp.float32


@struct.dataclass
class S5Carry:
    state: Array
    decay: Array


[docs] class S5Cell(MemoroidCellBase): config: S5Config
[docs] def setup(self): lam, _, _, v, _ = make_dplr_hippo(self.config.hidden_dim) lambda_real_init = jnp.asarray(lam.real, self.config.param_dtype) lambda_imag_init = jnp.asarray(lam.imag, self.config.param_dtype) self._v = v self._v_inv = v.conj().T self.lambda_real = self.param( "lambda_real", lambda rng, shape: lambda_real_init, (self.config.hidden_dim,), ) self.lambda_imag = self.param( "lambda_imag", lambda rng, shape: lambda_imag_init, (self.config.hidden_dim,), ) self.B = self.param( "b", lambda rng, shape: init_v_inv_b( lecun_normal(), rng, (self.config.hidden_dim, self.config.features), self._v_inv ), (self.config.hidden_dim, self.config.features, 2), ) match self.config.c_init: case "complex_normal": self.C = self.param( "c", _c_init_complex_normal, (self.config.features, self.config.hidden_dim, 2), ) case "lecun_normal": self.C = self.param( "c", partial(_c_init_lecun, v=self._v), (self.config.features, self.config.hidden_dim, 2), ) case "truncated_standard_normal": self.C = self.param( "c", partial(_c_init_truncated, v=self._v), (self.config.features, self.config.hidden_dim, 2), ) case _: raise ValueError(f"Invalid c_init: {self.config.c_init}") self.d = self.param("d", normal(stddev=1.0), (self.config.features,)) self.log_step = self.param( "log_step", log_step(self.config.dt_min, self.config.dt_max), (self.config.hidden_dim,) )
def _discretized_params(self): lambda_real = self.lambda_real lambda_imag = self.lambda_imag if self.config.clip_eigens: lambda_real = jnp.minimum(lambda_real, -1e-4) lam = jax.lax.complex(lambda_real, lambda_imag) b_tilde = jax.lax.complex(self.B[..., 0], self.B[..., 1]) c_tilde = jax.lax.complex(self.C[..., 0], self.C[..., 1]) step = self.config.step_rescale * jnp.exp(self.log_step[:, 0].astype(jnp.float32)) match self.config.discretization: case "zoh": lambda_bar, b_bar = discretize_zoh( lam.astype(jnp.complex64), b_tilde.astype(jnp.complex64), step.astype(jnp.complex64), ) case "bilinear": lambda_bar, b_bar = discretize_bilinear( lam.astype(jnp.complex64), b_tilde.astype(jnp.complex64), step.astype(jnp.complex64), ) case _: raise ValueError(f"Invalid discretization: {self.config.discretization}") d = self.d.astype(self.config.dtype or self.config.param_dtype) return lambda_bar, b_bar, c_tilde, d, lam, b_tilde, step
[docs] def __call__(self, x: Array, **kwargs) -> Carry: B, T, _ = x.shape lambda_bar, b_bar, _, _, _, _, _ = self._discretized_params() decay = jnp.broadcast_to(lambda_bar, (B, T, self.config.hidden_dim)) state = jnp.einsum('ij,btj->bti', b_bar, x) return S5Carry(state=state, decay=decay)
[docs] def binary_operator(self, a: Carry, b: Carry) -> Carry: return S5Carry( state=b.decay * a.state + b.state, decay=b.decay * a.decay, )
[docs] def read(self, h: Carry, x: Array, **kwargs) -> Array: _, _, c_tilde, d, _, _, _ = self._discretized_params() y = jnp.einsum('ij,btj->bti', c_tilde, h.state).real + d * x return y
[docs] def initialize_carry(self, key: jax.Array, input_shape: Tuple[int, ...]) -> Carry: *batch_dims, _ = input_shape state = jnp.zeros((*batch_dims, 1, self.config.hidden_dim), dtype=jnp.complex64) decay = jnp.ones((*batch_dims, 1, self.config.hidden_dim), dtype=jnp.complex64) return S5Carry(state=state, decay=decay)