Source code for memorax.networks.sequence_models.mamba2

from functools import partial
from typing import Tuple

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax import struct
from flax.typing import Dtype, Initializer

from memorax.networks.initializers import inverse_softplus, log_uniform
from memorax.utils.typing import Array, Carry

from .memoroid import MemoroidCellBase


@struct.dataclass
class Mamba2Config:
    """Hyperparameters for ``Mamba2Cell``.

    ``num_heads`` must be divisible by ``num_groups``: heads are laid out as
    ``(num_groups, num_heads // num_groups)`` and every head inside a group
    shares that group's ``B``/``C`` projection of width ``state_dim``.
    """

    # Width of the cell's final output projection.
    features: int
    # Number of SSM heads and the channel width of each head.
    num_heads: int = 8
    head_dim: int = 16
    # SSM state size per group (width of the B and C projections per group).
    state_dim: int = 16
    # Number of B/C groups the heads are split into.
    num_groups: int = 1
    # Temporal kernel width of the causal depthwise convolution.
    conv_dim: int = 4
    # Kernel initializer for the Dense layers; marked as a non-pytree field.
    kernel_init: Initializer = struct.field(
        default=nn.initializers.lecun_normal(),
        pytree_node=False,
    )
    # Computation dtype forwarded to the Flax layers (None defers to each
    # layer's own default) and the dtype used for parameters.
    dtype: Dtype | None = None
    param_dtype: Dtype = jnp.float32


@struct.dataclass
class Mamba2Carry:
    """Element of the linear recurrence built and combined by ``Mamba2Cell``.

    Attributes:
        state: SSM state, built by the cell with layout
            ``(..., seq, groups, heads_per_group, state_dim, head_dim)``.
        decay: multiplicative decay with trailing singleton axes,
            ``(..., seq, groups, heads_per_group, 1, 1)``, so that it
            broadcasts against ``state``.
    """

    state: Array
    decay: Array


class Mamba2Cell(MemoroidCellBase):
    """Mamba-2 style selective state-space layer written as a memoroid cell.

    The work is split across the hooks of ``MemoroidCellBase``:

    * ``__call__`` maps every input step to a recurrence element
      ``Mamba2Carry(state=outer(B, dt * x), decay=exp(dt * A))``.
    * ``binary_operator`` composes two elements of the linear recurrence
      ``h_t = decay_t * h_{t-1} + state_t``.
    * ``read`` contracts an aggregated state with ``C``, adds the ``D * x``
      skip term, applies RMSNorm multiplied by ``SiLU(gate)`` and projects
      back to ``config.features``.

    NOTE(review): aggregation of ``binary_operator`` over time presumably
    happens inside ``MemoroidCellBase`` (e.g. an associative scan) — confirm
    against the base class.
    """

    config: Mamba2Config

    def setup(self):
        """Create the SSM parameters and submodules.

        Raises:
            ValueError: if ``num_heads`` is not divisible by ``num_groups``.
        """
        config = self.config
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if config.num_heads % config.num_groups != 0:
            raise ValueError(
                f"num_heads ({config.num_heads}) must be divisible by "
                f"num_groups ({config.num_groups})"
            )
        hidden_dim = config.num_heads * config.head_dim
        group_projection_dim = config.num_groups * config.state_dim
        # The convolution runs over x, B and C concatenated on the channel axis.
        conv_channels = hidden_dim + 2 * group_projection_dim

        # Per-head SSM parameters; __call__ uses A = -exp(A_log).
        self.A_log = self.param("A_log", log_uniform(), (config.num_heads,))
        self.D = self.param("D", nn.initializers.ones, (config.num_heads,))
        self.dt_bias = self.param(
            "dt_bias", inverse_softplus(), (config.num_heads,)
        )

        projection = partial(
            nn.Dense,
            kernel_init=config.kernel_init,
            use_bias=False,
            dtype=config.dtype,
            param_dtype=config.param_dtype,
        )
        self.input_projection = projection(hidden_dim * 2)
        self.B = projection(group_projection_dim)
        self.C = projection(group_projection_dim)
        self.dt = projection(config.num_heads)
        # Depthwise (one filter per channel) convolution; padding only on the
        # left by conv_dim - 1 makes it causal along the sequence axis.
        self.conv = nn.Conv(
            conv_channels,
            kernel_size=(config.conv_dim,),
            padding=((config.conv_dim - 1, 0),),
            feature_group_count=conv_channels,
            dtype=config.dtype,
            param_dtype=config.param_dtype,
        )
        # BUGFIX: the first attribute of ``nn.RMSNorm`` is ``epsilon``; passing
        # hidden_dim positionally set epsilon = hidden_dim. The normalized size
        # is inferred from the input, so no size argument is needed.
        self.norm = nn.RMSNorm(
            dtype=config.dtype,
            param_dtype=config.param_dtype,
        )
        self.output_projection = nn.Dense(
            config.features,
            kernel_init=config.kernel_init,
            dtype=config.dtype,
            param_dtype=config.param_dtype,
        )

    def _project(self, x: Array):
        """Compute the per-step SSM inputs from ``x``.

        ``x`` must be rank 3: ``(batch, seq, in_features)``.

        Returns:
            Tuple ``(hidden, B, C, gate, dt)`` with shapes
            ``hidden: (batch, seq, groups, heads_per_group, head_dim)``,
            ``B``, ``C``: ``(batch, seq, groups, state_dim)``,
            ``gate: (batch, seq, num_heads * head_dim)`` and
            ``dt: (batch, seq, groups, heads_per_group)``.

        NOTE(review): the causal convolution only sees the steps inside ``x``
        (zero left-padding); no convolution context is stored in
        ``Mamba2Carry``. If a sequence is fed in chunks, the first
        ``conv_dim - 1`` steps of each chunk differ from a full-sequence pass.
        Confirm this is intended.
        """
        batch_size, sequence_length, _ = x.shape
        config = self.config
        hidden_dim = config.num_heads * config.head_dim
        group_projection_dim = config.num_groups * config.state_dim
        heads_per_group = config.num_heads // config.num_groups

        # ``gate`` bypasses the convolution; hidden, B and C go through it.
        hidden, gate = jnp.split(self.input_projection(x), 2, axis=-1)
        B = self.B(x)
        C = self.C(x)
        conv_input = jnp.concatenate([hidden, B, C], axis=-1)
        conv_output = nn.silu(self.conv(conv_input))

        # Split the channel axis back into its parts. Row-major reshape means
        # head h maps to (group, head_in_group) = divmod(h, heads_per_group).
        hidden = conv_output[..., :hidden_dim].reshape(
            batch_size,
            sequence_length,
            config.num_groups,
            heads_per_group,
            config.head_dim,
        )
        B = conv_output[..., hidden_dim : hidden_dim + group_projection_dim].reshape(
            batch_size, sequence_length, config.num_groups, config.state_dim
        )
        C = conv_output[..., hidden_dim + group_projection_dim :].reshape(
            batch_size, sequence_length, config.num_groups, config.state_dim
        )
        # softplus keeps the step size dt non-negative.
        dt = nn.softplus(self.dt(x) + self.dt_bias).reshape(
            batch_size, sequence_length, config.num_groups, heads_per_group
        )
        return hidden, B, C, gate, dt

    def __call__(self, x: Array, **kwargs) -> Carry:
        """Map each step of ``x`` to a recurrence element.

        Returns:
            ``Mamba2Carry`` whose ``state`` has shape
            ``(batch, seq, groups, heads_per_group, state_dim, head_dim)`` and
            whose ``decay`` has shape ``(batch, seq, groups, heads_per_group, 1, 1)``.
        """
        hidden, B, _, _, dt = self._project(x)
        # A = -exp(A_log) < 0, so decay = exp(dt * A) lies in (0, 1] for dt >= 0.
        A = -jnp.exp(self.A_log).reshape(
            self.config.num_groups, self.config.num_heads // self.config.num_groups
        )
        # Trailing singleton axes let decay broadcast over (state_dim, head_dim).
        decay = jnp.exp(dt * A)[..., None, None]
        # Outer product of B with the dt-scaled input: h[n, d] = B[n] * dt * x[d].
        h = jnp.einsum("btgn,btgp,btgpd->btgpnd", B, dt, hidden)
        return Mamba2Carry(state=h, decay=decay)

    def binary_operator(self, a: Carry, b: Carry) -> Carry:
        """Compose element ``a`` followed by element ``b``.

        Applying ``a`` then ``b`` to a state ``h`` gives
        ``b.decay * (a.decay * h + a.state) + b.state``; the returned element
        encodes exactly that map.
        """
        return Mamba2Carry(
            state=b.decay * a.state + b.state,
            decay=b.decay * a.decay,
        )

    def read(self, carry: Carry, x: Array, **kwargs) -> Array:
        """Produce outputs from aggregated states and the inputs ``x``.

        ``carry.state`` is expected to hold one state per step of ``x``, with
        layout ``(batch, seq, groups, heads_per_group, state_dim, head_dim)``
        as required by the einsum below.

        Returns:
            Array of shape ``(batch, seq, config.features)``.
        """
        batch_size, sequence_length, _ = x.shape
        hidden, _, C, gate, _ = self._project(x)
        y = jnp.einsum("btgn,btgpnd->btgpd", C, carry.state)
        D = self.D.reshape(
            self.config.num_groups, self.config.num_heads // self.config.num_groups
        )
        # Per-head skip connection D * x.
        y = y + jnp.einsum("gp,btgpd->btgpd", D, hidden)
        y = y.reshape(
            batch_size, sequence_length, self.config.num_heads * self.config.head_dim
        )
        # Normalize, then gate with the un-convolved branch of input_projection.
        y = self.norm(y) * nn.silu(gate)
        return self.output_projection(y)

    def initialize_carry(self, key: jax.Array, input_shape: Tuple[int, ...]) -> Carry:
        """Return the identity element of ``binary_operator``.

        With ``state = 0`` and ``decay = 1``, combining this carry with any
        element (on either side) leaves that element unchanged. All but the
        last entry of ``input_shape`` are kept as batch dimensions, followed
        by a singleton sequence axis. ``key`` is unused.
        """
        *batch_dims, _ = input_shape
        heads_per_group = self.config.num_heads // self.config.num_groups
        state = jnp.zeros(
            (
                *batch_dims,
                1,
                self.config.num_groups,
                heads_per_group,
                self.config.state_dim,
                self.config.head_dim,
            ),
            dtype=self.config.dtype,
        )
        decay = jnp.ones(
            (*batch_dims, 1, self.config.num_groups, heads_per_group, 1, 1),
            dtype=self.config.dtype,
        )
        return Mamba2Carry(state=state, decay=decay)