Source code for memorax.networks.sequence_models.s5

from typing import Tuple

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.typing import Dtype
from jax.nn.initializers import lecun_normal, normal
from memorax.utils.typing import Array, Carry

from .memoroid import MemoroidCellBase
from .utils import (
    discretize_bilinear,
    discretize_zoh,
    init_cv,
    init_log_steps,
    init_v_inv_b,
    make_dplr_hippo,
    truncated_standard_normal,
)


[docs] class S5Cell(MemoroidCellBase): """S5 (Structured State Space for Sequences) algebra. Uses HIPPO matrix initialization and discretization for stable long-range dependencies. Element: (decay, state) where state accumulates B @ x Combine: (a_j * a_i, a_j * s_i + s_j) """ features: int state_size: int c_init: str = "truncated_standard_normal" discretization: str = "zoh" dt_min: float = 0.001 dt_max: float = 0.1 clip_eigens: bool = False step_rescale: float = 1.0 dtype: Dtype = jnp.float32 param_dtype: Dtype = jnp.float32
[docs] def setup(self): lam, _, _, v, _ = make_dplr_hippo(self.state_size) self._lambda_real_init = jnp.asarray(lam.real, self.param_dtype) self._lambda_imag_init = jnp.asarray(lam.imag, self.param_dtype) self._v = v self._v_inv = v.conj().T
def _discretized_params(self): lambda_real = self.param( "lambda_real", lambda rng, shape: self._lambda_real_init, (self.state_size,), ) lambda_imag = self.param( "lambda_imag", lambda rng, shape: self._lambda_imag_init, (self.state_size,), ) if self.clip_eigens: lambda_real = jnp.minimum(lambda_real, -1e-4) lam = jax.lax.complex(lambda_real, lambda_imag) b = self.param( "b", lambda rng, shape: init_v_inv_b( lecun_normal(), rng, (self.state_size, self.features), self._v_inv ), (self.state_size, self.features, 2), ) b_tilde = jax.lax.complex(b[..., 0], b[..., 1]) match self.c_init: case "complex_normal": c = self.param( "c", normal(stddev=0.5**0.5), (self.features, self.state_size, 2) ) c_tilde = jax.lax.complex(c[..., 0], c[..., 1]) case "lecun_normal": c = self.param( "c", lambda rng, shape: init_cv( lecun_normal(), rng, (self.features, self.state_size, 2), self._v, ), (self.features, self.state_size, 2), ) c_tilde = jax.lax.complex(c[..., 0], c[..., 1]) case "truncated_standard_normal": c = self.param( "c", lambda rng, shape: init_cv( truncated_standard_normal, rng, (self.features, self.state_size, 2), self._v, ), (self.features, self.state_size, 2), ) c_tilde = jax.lax.complex(c[..., 0], c[..., 1]) case _: raise ValueError(f"Invalid c_init: {self.c_init}") d = self.param("d", normal(stddev=1.0), (self.features,)) log_step = self.param( "log_step", init_log_steps, (self.state_size, self.dt_min, self.dt_max) ) step = self.step_rescale * jnp.exp(log_step[:, 0].astype(jnp.float32)) match self.discretization: case "zoh": lambda_bar, b_bar = discretize_zoh( lam.astype(jnp.complex64), b_tilde.astype(jnp.complex64), step.astype(jnp.complex64), ) case "bilinear": lambda_bar, b_bar = discretize_bilinear( lam.astype(jnp.complex64), b_tilde.astype(jnp.complex64), step.astype(jnp.complex64), ) case _: raise ValueError(f"Invalid discretization: {self.discretization}") return lambda_bar, b_bar, c_tilde, d.astype(self.dtype)
[docs] @nn.compact def __call__(self, x: Array, **kwargs) -> Carry: B, T, _ = x.shape lambda_bar, b_bar, _, _ = self._discretized_params() # Decay: broadcast lambda to (B, T, state_size) decay = jnp.broadcast_to(lambda_bar, (B, T, self.state_size)) # State: B @ x for each timestep state = jax.vmap(jax.vmap(lambda xi: b_bar @ xi))(x) return (decay, state)
[docs] def binary_operator(self, a: Carry, b: Carry) -> Carry: """Diagonal SSM combine: (a_j * a_i, a_j * s_i + s_j)""" decay_i, state_i = a decay_j, state_j = b return (decay_j * decay_i, decay_j * state_i + state_j)
[docs] @nn.compact def read(self, h: Carry, x: Array, **kwargs) -> Array: _, _, c_tilde, d = self._discretized_params() _, state = h # Output: C @ state + D * x y = jax.vmap(jax.vmap(lambda si, xi: (c_tilde @ si).real + d * xi))(state, x) return y
[docs] def initialize_carry(self, key: jax.Array, input_shape: Tuple[int, ...]) -> Carry: batch_size, *_ = input_shape decay = jnp.ones((batch_size, 1, self.state_size), dtype=jnp.complex64) state = jnp.zeros((batch_size, 1, self.state_size), dtype=jnp.complex64) return (decay, state)