# Source code for memorax.networks.sequence_models.linear_attention

from functools import partial
from typing import Optional, Tuple

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.typing import Dtype, Initializer

from memorax.utils.typing import Array, Carry

from .memoroid import MemoroidCellBase


class LinearAttentionCell(MemoroidCellBase):
    """Linear attention as a memoroid algebra.

    Uses the ELU+1 kernel feature map to linearize attention, so per-step
    contributions can be combined with an associative scan. Based on
    "Transformers are RNNs" (Katharopoulos et al., 2020).

    Element: ``(state, normalizer)`` where, for a single time step:

    - state: outer product ``v ⊗ φ(k)`` of shape (..., H, head_dim, head_dim);
      entry ``[i, j]`` is ``v_i * φ(k)_j``.
    - normalizer: the feature-mapped key ``φ(k)`` of shape (..., H, head_dim).

    Combine: element-wise addition of states and normalizers, so combining a
    prefix of elements yields ``Σ v ⊗ φ(k)`` and ``Σ φ(k)``.
    """

    head_dim: int
    num_heads: int
    kernel_init: Initializer
    bias_init: Initializer
    param_dtype: Dtype
    dtype: Optional[Dtype] = None
    # Lower bound applied to the attention denominator in ``read``.
    eps: float = 1e-6

    def _feature_map(self, x: Array) -> Array:
        """ELU+1 feature map as in the original paper (strictly positive)."""
        return nn.elu(x) + 1

    def _projection(self):
        """Return a factory for the bias-free per-head key/value/query projections.

        Only a ``functools.partial`` is returned; the submodule itself is
        instantiated (with its ``name``) inside the calling compact method.
        """
        return partial(
            nn.DenseGeneral,
            features=(self.num_heads, self.head_dim),
            use_bias=False,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )

    @nn.compact
    def __call__(self, x: Array, **kwargs) -> Carry:
        """Compute per-step key-value outer products for memory storage.

        Args:
            x: Input of shape (B, T, D).

        Returns:
            Carry tuple of (state, normalizer) for each time step, where:
            - state: ``v ⊗ φ(k)`` of shape (B, T, H, head_dim, head_dim)
            - normalizer: ``φ(k)`` of shape (B, T, H, head_dim)
            Summation over time is performed by ``binary_operator``.
        """
        projection = self._projection()
        key = projection(name="key")(x)
        value = projection(name="value")(x)

        # Apply the positive feature map to keys only; values stay linear.
        key = self._feature_map(key)

        # state[..., i, j] = v_i * φ(k)_j
        state = jnp.einsum("bthi,bthj->bthij", value, key)
        # Per-step normalizer term φ(k); its prefix sum is Σ φ(k).
        normalizer = key
        return (state, normalizer)

    def binary_operator(self, a: Carry, b: Carry) -> Carry:
        """Combine two elements via element-wise addition."""
        state_i, norm_i = a
        state_j, norm_j = b
        return (state_i + state_j, norm_i + norm_j)

    @nn.compact
    def read(self, h: Carry, x: Array, **kwargs) -> Array:
        """Query accumulated memory to produce output.

        Args:
            h: Accumulated (state, normalizer), with the shapes produced by
                ``__call__`` after summation over time.
            x: Original input of shape (B, T, D).

        Returns:
            Output of shape (B, T, D).
        """
        batch_size, sequence_length, in_features = x.shape

        query = self._projection()(name="query")(x)
        query = self._feature_map(query)

        state, normalizer = h

        # Numerator: S @ φ(q), where S = Σ v ⊗ φ(k).
        numerator = jnp.einsum("bthij,bthj->bthi", state, query)
        # Denominator: φ(q) · z, where z = Σ φ(k); clamped to avoid division
        # by zero (e.g. when the accumulated normalizer is still all zeros).
        denominator = jnp.einsum("bthi,bthi->bth", query, normalizer)
        denominator = jnp.maximum(denominator, self.eps)[:, :, :, None]

        output = numerator / denominator

        # Merge heads back into a single feature axis.
        hidden_dim = self.num_heads * self.head_dim
        output = output.reshape(batch_size, sequence_length, hidden_dim)

        # Honour the configured compute/parameter dtypes here as well, matching
        # the key/value/query projections.
        output = nn.RMSNorm(dtype=self.dtype, param_dtype=self.param_dtype)(output)
        output = nn.Dense(
            features=in_features,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )(output)
        return output

    def initialize_carry(self, key: jax.Array, input_shape: Tuple[int, ...]) -> Carry:
        """Initialize carry with zero state and normalizer.

        Args:
            key: PRNG key (unused; zeros are deterministic).
            input_shape: Shape whose leading entry is the batch size.

        Returns:
            (state, normalizer) of shapes (B, 1, H, head_dim, head_dim) and
            (B, 1, H, head_dim), filled with zeros in JAX's default float dtype.
        """
        batch_size, *_ = input_shape
        state = jnp.zeros(
            (batch_size, 1, self.num_heads, self.head_dim, self.head_dim)
        )
        normalizer = jnp.zeros((batch_size, 1, self.num_heads, self.head_dim))
        return (state, normalizer)