# Source code for memorax.networks.sequence_models.mlstm

"""mLSTM as a Memoroid algebra for efficient parallel computation.

The mLSTM (matrix LSTM) is a linear attention variant with learned
per-step gating. By formulating it as a Memoroid algebra, we can
use associative scan for O(log n) parallel depth instead of O(n)
sequential RNN computation.

Core recurrence:
    C_new = f * C + i * (k ⊗ v)   # matrix memory
    n_new = f * n + i * k          # normalizer
    output = (q @ C) / max(|q @ n|, 1)   # query (normalizer magnitude floored at 1)

This is associative when we track cumulative decay properly.
"""

from functools import partial
from typing import Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.linen import initializers
from flax.linen.linear import Dense
from flax.typing import Dtype

from memorax.utils.typing import Array, Carry

from .memoroid import MemoroidCellBase
from .utils import wang_init


def _validated_head_dim(hidden_dim: int, num_heads: int) -> int:
    """Return the per-head width, rejecting a hidden_dim that does not split evenly."""
    if hidden_dim % num_heads != 0:
        raise ValueError(
            f"hidden_dim ({hidden_dim}) must be divisible by "
            f"num_heads ({num_heads})."
        )
    return hidden_dim // num_heads


class mLSTMCell(MemoroidCellBase):
    """Matrix LSTM as a Memoroid algebra.

    Uses gated linear attention with a matrix memory. Each time step is
    turned into an element that can be merged with ``binary_operator``.

    Element: (log_f, C, n, m) where:
        - log_f: log forget gate; per step when produced by ``__call__``,
          summed over the covered span once elements are combined
        - C: matrix memory contribution (k ⊗ v), stored relative to exp(m)
        - n: normalizer contribution (k), stored relative to exp(m)
        - m: log-scale reference shared by C and n for numerical stability

    Combine: accumulates states with relative exponential decay.

    Attributes:
        features: Output feature dimension.
        hidden_dim: Width of the internal projections; split evenly
            across heads.
        num_heads: Number of attention heads.
        dropout_rate: Dropout rate applied to the output of ``read``.
        dtype: Data type for computation.
        param_dtype: Data type for parameters.

    The per-head width ``head_dim`` is not a field; it is derived as
    ``hidden_dim // num_heads`` and ``hidden_dim`` must be divisible by
    ``num_heads``.
    """

    features: int
    hidden_dim: int
    num_heads: int = 4
    dropout_rate: float = 0.0
    dtype: Dtype | None = None
    param_dtype: Dtype = jnp.float32

    @nn.compact
    def __call__(self, x: Array, **kwargs) -> Carry:
        """Compute per-step mLSTM elements for the parallel scan.

        Args:
            x: Input of shape (B, T, D).
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            Carry tuple of (log_f, C, n, m) where:
                - log_f: (B, T, NH, 1, 1) per-step log forget gate
                - C: (B, T, NH, DH, DH) matrix memory contribution
                - n: (B, T, NH, DH, 1) normalizer contribution
                - m: (B, T, NH, 1, 1) log-scale reference (the log input gate)

        Raises:
            ValueError: If ``hidden_dim`` is not divisible by ``num_heads``.
        """
        B, T, _ = x.shape
        # Validate before building any sub-module so a bad config fails early.
        head_dim = _validated_head_dim(self.hidden_dim, self.num_heads)

        # Project to hidden dimension
        x_proj = Dense(
            features=self.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            name="in_proj",
        )(x)

        # Q, K, V projections
        projection = partial(
            Dense,
            features=self.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
        )
        q = projection(name="q")(x_proj)
        k = projection(name="k")(x_proj)
        v = projection(name="v")(x_proj)

        # Reshape to heads: (B, T, NH, DH)
        q = q.reshape(B, T, self.num_heads, head_dim)
        k = k.reshape(B, T, self.num_heads, head_dim)
        v = v.reshape(B, T, self.num_heads, head_dim)

        # Gate projections from concatenated qkv
        qkv = jnp.concatenate([q, k, v], axis=-1)  # (B, T, NH, 3*DH)
        qkv_flat = qkv.reshape(B, T, -1)  # (B, T, NH*3*DH)

        # Kernels start at zero, so at initialisation each gate equals its
        # bias: small random values for the input gate, 4.0 for the forget
        # gate (sigmoid(4) is close to 1, i.e. almost no forgetting).
        gate = partial(
            Dense,
            features=self.num_heads,
            kernel_init=initializers.zeros_init(),
            dtype=self.dtype,
            param_dtype=self.param_dtype,
        )
        i_gate = gate(name="wi", bias_init=initializers.normal(stddev=0.1))(qkv_flat)
        f_gate = gate(name="wf", bias_init=initializers.constant(4.0))(qkv_flat)

        # Reshape gates: (B, T, NH)
        i_gate = i_gate.reshape(B, T, self.num_heads)
        f_gate = f_gate.reshape(B, T, self.num_heads)

        # Log-space gates for numerical stability.
        # -softplus(-f) == log(sigmoid(f)).
        log_f = -jax.nn.softplus(-f_gate)  # (B, T, NH)
        log_i = i_gate  # Keep in log space

        # For a single step the log-scale reference is the log input gate
        # itself, so the stabilised input gate below evaluates to 1 and the
        # gate magnitude is carried entirely by m.
        m = log_i[:, :, :, None, None]  # (B, T, NH, 1, 1)
        i_stable = jnp.exp(log_i - m.squeeze(-1).squeeze(-1))  # (B, T, NH)

        # Scale k by 1/sqrt(head_dim) (v is left unscaled).
        k = k / jnp.sqrt(head_dim)

        # C_contribution = i * (k ⊗ v): (B, T, NH, DH, DH)
        k_col = k[:, :, :, :, None]  # (B, T, NH, DH, 1)
        v_row = v[:, :, :, None, :]  # (B, T, NH, 1, DH)
        kv_outer = k_col @ v_row  # (B, T, NH, DH, DH)
        i_expanded = i_stable[:, :, :, None, None]  # (B, T, NH, 1, 1)
        C = i_expanded * kv_outer  # (B, T, NH, DH, DH)

        # n_contribution = i * k: (B, T, NH, DH, 1)
        n = i_stable[:, :, :, None] * k  # (B, T, NH, DH)
        n = n[:, :, :, :, None]  # (B, T, NH, DH, 1)

        # Reshape log_f for broadcasting: (B, T, NH, 1, 1)
        log_f = log_f[:, :, :, None, None]

        return (log_f, C, n, m)

    def binary_operator(self, a: Carry, b: Carry) -> Carry:
        """Combine two elements with decay-weighted accumulation.

        When combining element a (earlier positions) with element b (later
        positions), a's state decays by b's log forget gate. C and n are kept
        relative to a shared log-scale m so the exponentials stay bounded.

        Args:
            a: Earlier element (log_f_a, C_a, n_a, m_a).
            b: Later element (log_f_b, C_b, n_b, m_b).

        Returns:
            Combined element (log_f, C, n, m) with accumulated state.
        """
        log_f_a, C_a, n_a, m_a = a
        log_f_b, C_b, n_b, m_b = b

        # Combined cumulative decay
        log_f_combined = log_f_a + log_f_b

        # a's log-scale after being decayed across b's span
        m_a_decayed = m_a + log_f_b
        m_combined = jnp.maximum(m_a_decayed, m_b)

        # Both exponents are <= 0, so the scale factors lie in (0, 1].
        scale_a = jnp.exp(m_a_decayed - m_combined)
        scale_b = jnp.exp(m_b - m_combined)

        # Combine states
        C_combined = scale_a * C_a + scale_b * C_b
        n_combined = scale_a * n_a + scale_b * n_b

        return (log_f_combined, C_combined, n_combined, m_combined)

    @nn.compact
    def read(self, h: Carry, x: Array, **kwargs) -> Array:
        """Query accumulated memory to produce output.

        Args:
            h: Accumulated state (log_f, C, n, m).
            x: Original input of shape (B, T, D).
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            Output of shape (B, T, features).

        Raises:
            ValueError: If ``hidden_dim`` is not divisible by ``num_heads``.
        """
        B, T, _ = x.shape
        head_dim = _validated_head_dim(self.hidden_dim, self.num_heads)

        # Project input for query. The sub-module names match the ones used
        # in __call__ ("in_proj", "q").
        # NOTE(review): presumably this is meant to reuse the same parameters
        # as __call__ — confirm against MemoroidCellBase / Flax scoping.
        x_proj = Dense(
            features=self.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            name="in_proj",
        )(x)

        q = Dense(
            features=self.hidden_dim,
            use_bias=False,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            name="q",
        )(x_proj)

        # Reshape to heads: (B, T, NH, DH)
        q = q.reshape(B, T, self.num_heads, head_dim)
        # NOTE(review): k is also divided by sqrt(head_dim) in __call__, so
        # q·k ends up scaled by 1/head_dim overall — confirm this is intended.
        q = q / jnp.sqrt(head_dim)

        # Only the memory and the normalizer are read; log_f and m are ignored.
        _, C, n, _ = h

        # Query the memory: h_tilde = (q @ C) / normalize(q @ n)
        q_row = q[:, :, :, None, :]  # (B, T, NH, 1, DH)

        # q @ C: (B, T, NH, 1, DH) @ (B, T, NH, DH, DH) -> (B, T, NH, 1, DH)
        qC = (q_row @ C).squeeze(-2)  # (B, T, NH, DH)

        # q @ n: (B, T, NH, 1, DH) @ (B, T, NH, DH, 1) -> (B, T, NH, 1, 1)
        qn = (q_row @ n).squeeze(-2).squeeze(-1)  # (B, T, NH)

        # Floor the normalizer magnitude at 1 so small |q·n| cannot blow up
        # the output.
        # NOTE(review): the floor is applied to the exp(-m)-relative
        # normalizer without using m — confirm this is the intended form.
        normalizer = jnp.maximum(jnp.abs(qn), 1.0)[:, :, :, None]
        h_tilde = qC / (normalizer + 1e-6)

        # Merge heads back: (B, T, hidden_dim)
        h_tilde = h_tilde.reshape(B, T, self.hidden_dim)

        # Output projection
        y = Dense(
            features=self.features,
            use_bias=False,
            kernel_init=wang_init(self.hidden_dim, num_blocks=1),
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            name="out_proj",
        )(h_tilde)

        # Dropout is active only when a "dropout" RNG stream was supplied.
        y = nn.Dropout(
            rate=self.dropout_rate, deterministic=not self.has_rng("dropout")
        )(y)

        return y

    def initialize_carry(self, key: jax.Array, input_shape: Tuple[int, ...]) -> Carry:
        """Initialize carry with identity decay and zero state.

        Args:
            key: Random key (unused).
            input_shape: Shape of input; only the leading batch size is read.

        Returns:
            Initial carry tuple (log_f, C, n, m) with a singleton time axis.

        Raises:
            ValueError: If ``hidden_dim`` is not divisible by ``num_heads``.
        """
        batch_size, *_ = input_shape
        head_dim = _validated_head_dim(self.hidden_dim, self.num_heads)

        # Initial cumulative log decay is 0 (no decay yet)
        log_f = jnp.zeros((batch_size, 1, self.num_heads, 1, 1))

        # Initial state is zero
        C = jnp.zeros((batch_size, 1, self.num_heads, head_dim, head_dim))
        n = jnp.zeros((batch_size, 1, self.num_heads, head_dim, 1))

        # Large negative finite stand-in for -inf: any real m wins the
        # maximum in binary_operator and this element's scale factor
        # underflows to 0.
        m = jnp.full((batch_size, 1, self.num_heads, 1, 1), -1e9)

        return (log_f, C, n, m)