Source code for memorax.networks.sequence_models.ffm

from functools import partial
from typing import Optional, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.linen import initializers
from flax.linen.linear import Dense, default_kernel_init
from flax.linen.normalization import LayerNorm
from flax.typing import Dtype, Initializer
from memorax.utils.typing import Array, Carry

from .memoroid import MemoroidCellBase


class FFMCell(MemoroidCellBase):
    """Fast and Forgetful Memory (FFM) cell expressed as a memoroid algebra.

    Each element of the algebra is a pair ``(state, timestep)``. Two elements
    ``i`` (earlier) and ``j`` (later) are combined as::

        (state_i * gamma(t_j - t_i) + state_j, t_j)

    where ``gamma(dt) = exp((a + i*b) * dt)``. The learnable vector ``a``
    (one entry per memory slot) sets the decay rate and the learnable vector
    ``b`` (one entry per context slot) sets the oscillation frequency.
    """

    # Output width of the ``output_gate``, ``skip`` and ``mix`` projections
    # created in ``read``.
    features: int
    # Number of decay rates in ``a``; also the width of the ``pre`` and
    # ``input_gate`` projections created in ``__call__``.
    memory_size: int
    # Number of angular frequencies in ``b``.
    context_size: int
    # Shortest / longest period used to initialise ``b`` (b = 2*pi / period).
    min_period: int = 1
    max_period: int = 1024
    # Margin subtracted from the decay limit computed in ``setup``.
    epsilon: float = 0.01
    # ``log(beta) / max_period`` caps the initial (least negative) decay rate.
    beta: float = 0.01
    kernel_init: Initializer = default_kernel_init
    bias_init: Initializer = initializers.zeros_init()
    dtype: Dtype | None = None
    param_dtype: Dtype = jnp.float32

    def setup(self) -> None:
        """Create the decay (``a``) and frequency (``b``) parameters."""
        # exp(limit * max_period) equals the largest finite ``param_dtype``
        # value scaled by exp(-epsilon * max_period).
        # NOTE(review): presumably an overflow guard for gamma - confirm.
        limit = (
            jnp.log(jnp.finfo(self.param_dtype).max) / self.max_period
            - self.epsilon
        )
        self.limit = limit

        decay_low = -limit + self.epsilon
        decay_cap = jnp.minimum(-1e-6, jnp.log(self.beta) / self.max_period)
        decay_high = jnp.maximum(decay_cap, decay_low)

        def init_decay(_key):
            # Decay rates spread evenly between the two bounds; key is unused.
            return jnp.linspace(
                decay_low, decay_high, self.memory_size, dtype=self.param_dtype
            )

        def init_frequency(_key):
            # Angular frequencies of evenly spaced periods; key is unused.
            periods = jnp.linspace(
                self.min_period,
                self.max_period,
                self.context_size,
                dtype=self.param_dtype,
            )
            return (2 * jnp.pi) / periods

        self.a = self.param("a", init_decay)
        self.b = self.param("b", init_frequency)

    def _complex_dtype(self):
        """Return complex64 for float32 params, complex128 for anything else."""
        if self.param_dtype == jnp.float32:
            return jnp.complex64
        return jnp.complex128

    def _gamma(self, delta_t):
        """Return ``exp((a + i*b) * delta_t)`` for a time difference ``delta_t``.

        The result has the shape of ``delta_t`` followed by two trailing axes
        that broadcast against ``(memory_size, context_size)``.
        """
        # Keep the decay rate strictly negative and within the setup limit.
        decay = jnp.clip(self.a, min=-self.limit, max=-1e-8)
        decay = decay.reshape((1, self.memory_size, 1))
        frequency = self.b.reshape((1, 1, self.context_size))
        exponent = jax.lax.complex(decay, frequency)
        return jnp.exp(exponent * jnp.expand_dims(delta_t, (-2, -1)))

    @nn.compact
    def __call__(self, inputs: Array, **kwargs) -> Carry:
        """Project ``inputs`` into per-timestep ``(state, timestep)`` elements.

        ``inputs`` must be rank 3; its first two axes are used as the batch
        and time axes. The returned state is complex with two trailing axes
        ``(memory_size, context_size)``; the returned timestep is a complex
        array of shape ``(batch, time)`` holding ``0 .. time - 1``.
        """
        batch_size, seq_len, _ = inputs.shape
        dense_kwargs = dict(
            use_bias=True,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
        )

        # Gated input projection.
        projected = Dense(features=self.memory_size, name="pre", **dense_kwargs)(
            inputs
        )
        gate_logits = Dense(
            features=self.memory_size, name="input_gate", **dense_kwargs
        )(inputs)
        gated = projected * nn.sigmoid(gate_logits)

        # Copy every memory slot across the context axis, then lift to complex
        # with a zero imaginary part.
        real_state = jnp.repeat(
            jnp.expand_dims(gated, -1), self.context_size, axis=-1
        )
        state = jax.lax.complex(real_state, jnp.zeros_like(real_state))

        # Integer positions 0..seq_len-1, stored as complex so they match the
        # carry produced by ``initialize_carry``.
        positions = jnp.broadcast_to(
            jnp.arange(seq_len, dtype=self.param_dtype), (batch_size, seq_len)
        )
        timestep = jax.lax.complex(positions, jnp.zeros_like(positions))
        return (state, timestep)

    def binary_operator(self, a: Carry, b: Carry) -> Carry:
        """Combine two elements: ``(state_a * gamma(t_b - t_a) + state_b, t_b)``."""
        earlier_state, earlier_t = a
        later_state, later_t = b
        decay = self._gamma(later_t - earlier_t)
        return (earlier_state * decay + later_state, later_t)

    @nn.compact
    def read(self, h: Carry, x: Array, **kwargs) -> Array:
        """Decode the recurrent state ``h`` into an output, gated against ``x``.

        The real and imaginary parts of the state are flattened and mixed by a
        dense layer, layer-normalised, and blended with a dense skip
        projection of ``x`` through a sigmoid output gate.
        """
        state, _ = h
        dense_kwargs = dict(
            use_bias=True,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
        )

        # Output gate and skip path both come straight from the input.
        gate_logits = Dense(
            features=self.features, name="output_gate", **dense_kwargs
        )(x)
        output_gate = nn.sigmoid(gate_logits)
        skip = Dense(features=self.features, name="skip", **dense_kwargs)(x)

        # Place real and imaginary parts side by side on the last axis, then
        # merge the two trailing axes into one feature axis.
        parts = jnp.concatenate([jnp.real(state), jnp.imag(state)], axis=-1)
        flat = parts.reshape(parts.shape[:-2] + (-1,))
        mixed = Dense(features=self.features, name="mix", **dense_kwargs)(flat)

        norm = LayerNorm(use_scale=False, use_bias=False, name="ln")
        gated_memory = norm(mixed) * output_gate
        gated_skip = skip * (1.0 - output_gate)
        return gated_memory + gated_skip

    def initialize_carry(self, key: jax.Array, input_shape: Tuple[int, ...]) -> Carry:
        """Return the initial ``(state, timestep)`` carry.

        Only the leading entry of ``input_shape`` (the batch size) is used;
        ``key`` is not used. The state is all zeros and the timestep is -1, so
        the first real timestep (0) lies one step after the carry.
        """
        batch_size, *_unused = input_shape
        complex_dtype = self._complex_dtype()
        state = jnp.zeros(
            (batch_size, 1, self.memory_size, self.context_size),
            dtype=complex_dtype,
        )
        timestep = jnp.full((batch_size, 1), -1, dtype=complex_dtype)
        return (state, timestep)