Source code for memorax.networks.sequence_models.lru

from functools import partial
from typing import Tuple

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.typing import Dtype
from memorax.utils.typing import Array, Carry

from .memoroid import MemoroidCellBase


def _nu_init(key, shape, r_min, r_max, dtype=jnp.float32):
    u = jax.random.uniform(key=key, shape=shape, dtype=dtype)
    return jnp.log(-0.5 * jnp.log(u * (r_max**2 - r_min**2) + r_min**2))


def _theta_init(key, shape, max_phase, dtype=jnp.float32):
    u = jax.random.uniform(key, shape=shape, dtype=dtype)
    return jnp.log(max_phase * u)


def _gamma_log_init(key, lamb):
    nu, theta = lamb
    diag_lambda = jnp.exp(-jnp.exp(nu) + 1j * jnp.exp(theta))
    return jnp.log(jnp.sqrt(1 - jnp.abs(diag_lambda) ** 2))


def _matrix_init(key, shape, dtype=jnp.float32, normalization=1):
    return jax.random.normal(key=key, shape=shape, dtype=dtype) / normalization


class LRUCell(MemoroidCellBase):
    """Linear Recurrent Unit algebra.

    Uses exponential parameterization of eigenvalues for stable training.

    Element: (decay, state)
    Combine: (a_j * a_i, a_j * s_i + s_j)

    Attributes:
        features: Size of the input/output feature axis.
        hidden_dim: Number of complex diagonal recurrent states.
        r_min: Lower bound for the initial eigenvalue magnitudes.
        r_max: Upper bound for the initial eigenvalue magnitudes.
        max_phase: Upper bound for the initial eigenvalue phases.
        dtype: Declared computation dtype; this cell does not currently
            apply it to any computation.
        param_dtype: Real dtype used when initialising every parameter.
    """

    features: int
    hidden_dim: int
    r_min: float = 0.0
    r_max: float = 1.0
    max_phase: float = 6.28
    dtype: Dtype = jnp.float32
    param_dtype: Dtype = jnp.float32

    def setup(self):
        """Create eigenvalue, input-normalisation and projection parameters.

        Every initializer receives ``param_dtype`` by keyword. Passing it
        positionally through ``self.param`` would collide with the arguments
        already bound by ``partial``.
        """
        pdt = self.param_dtype
        self.theta_log = self.param(
            "theta_log",
            partial(_theta_init, max_phase=self.max_phase, dtype=pdt),
            (self.hidden_dim,),
        )
        self.nu_log = self.param(
            "nu_log",
            partial(_nu_init, r_min=self.r_min, r_max=self.r_max, dtype=pdt),
            (self.hidden_dim,),
        )
        # Derived from the eigenvalue params above, so its dtype follows theirs.
        self.gamma_log = self.param(
            "gamma_log", _gamma_log_init, (self.nu_log, self.theta_log)
        )
        b_init = partial(
            _matrix_init, dtype=pdt, normalization=jnp.sqrt(2 * self.features)
        )
        self.B_real = self.param("B_real", b_init, (self.hidden_dim, self.features))
        self.B_imag = self.param("B_imag", b_init, (self.hidden_dim, self.features))
        c_init = partial(
            _matrix_init, dtype=pdt, normalization=jnp.sqrt(self.hidden_dim)
        )
        self.C_real = self.param("C_real", c_init, (self.features, self.hidden_dim))
        self.C_imag = self.param("C_imag", c_init, (self.features, self.hidden_dim))
        self.D = self.param("D", partial(_matrix_init, dtype=pdt), (self.features,))

    def __call__(self, x: Array, **kwargs) -> Carry:
        """Lift inputs into scan elements ``(decay, state)``.

        Args:
            x: Real input of shape ``(batch, time, features)``. The unpacking
                below requires rank 3.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            ``(decay, state)``, both complex with shape
            ``(batch, time, hidden_dim)``. ``decay`` repeats the diagonal
            eigenvalues at every step. ``state`` is ``B_norm @ x[b, t]``.
        """
        batch, time, _ = x.shape
        diag_lambda = jnp.exp(-jnp.exp(self.nu_log) + 1j * jnp.exp(self.theta_log))
        # Scale each row of B by gamma = sqrt(1 - |lambda|^2) (see _gamma_log_init).
        B_norm = jax.lax.complex(self.B_real, self.B_imag) * jnp.exp(
            self.gamma_log
        )[:, None]
        decay = jnp.broadcast_to(diag_lambda, (batch, time, self.hidden_dim))
        # One batched contraction over features instead of a doubly-vmapped matvec.
        state = jnp.einsum("hf,btf->bth", B_norm, x)
        return (decay, state)

    def binary_operator(self, a: Carry, b: Carry) -> Carry:
        """Diagonal SSM combine: (a_j * a_i, a_j * s_i + s_j)

        ``a`` is the earlier element ``(decay_i, state_i)`` and ``b`` the later
        one. The result applies ``b``'s decay to ``a``'s state before adding
        ``b``'s state.
        """
        decay_i, state_i = a
        decay_j, state_j = b
        return (decay_j * decay_i, decay_j * state_i + state_j)

    def read(self, h: Carry, x: Array, **kwargs) -> Array:
        """Project the hidden state to outputs: ``y = Re(C @ state) + D * x``.

        Args:
            h: Carry ``(decay, state)``; only ``state`` of shape
                ``(batch, time, hidden_dim)`` is used.
            x: Inputs of shape ``(batch, time, features)``, providing the
                skip term.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            Real array of shape ``(batch, time, features)``.
        """
        C = jax.lax.complex(self.C_real, self.C_imag)
        _, state = h
        return jnp.einsum("fh,bth->btf", C, state).real + self.D * x

    def initialize_carry(self, key: jax.Array, input_shape: Tuple[int, ...]) -> Carry:
        """Return the identity scan element for a batch.

        Args:
            key: Unused; accepted for interface compatibility.
            input_shape: Shape whose first entry is the batch size.

        Returns:
            ``(decay, state)`` of shape ``(batch, 1, hidden_dim)``: ``decay``
            is all ones and ``state`` all zeros. That pair is the identity of
            ``binary_operator``.
        """
        batch_size, *_ = input_shape
        # complex64 for float32/bfloat16 params; widens with wider param_dtype
        # so the carry matches the dtype of ``decay`` produced in __call__.
        carry_dtype = jnp.result_type(self.param_dtype, jnp.complex64)
        decay = jnp.ones((batch_size, 1, self.hidden_dim), dtype=carry_dtype)
        state = jnp.zeros((batch_size, 1, self.hidden_dim), dtype=carry_dtype)
        return (decay, state)