Source code for memorax.networks.sequence_models.shm

from functools import partial
from typing import Any, TypeVar

import flax.linen as nn
import jax
from flax.linen import LayerNorm, initializers
from flax.linen.activation import sigmoid
from flax.linen.linear import Dense, default_kernel_init
from flax.linen.module import compact, nowrap
from flax.linen.recurrent import RNNCellBase
from flax.typing import Array, Dtype, Initializer, PRNGKey
from jax import numpy as jnp
from jax import random
from memorax.networks.sequence_models.utils import xavier_uniform

A = TypeVar("A")
Carry = Any
CarryHistory = Any
Output = Any


class SHMCell(RNNCellBase):
    """Stable Hadamard Memory (SHM) cell.

    The carry is a square memory matrix with shape
    ``(*batch, features, features)``. Every step rescales it element-wise
    with a calibration matrix ``C`` and adds an outer-product update ``U``;
    the step output is a dense projection of the memory read out with a
    normalized query.

    Attributes:
        features: Width of the value/key/query/calibration projections and
            side length of the memory matrix.
        output_features: Width of the per-step output.
        num_thetas: Number of rows in the learned ``theta_table``.
        sample_theta: When true and a ``"memory"`` RNG stream is available,
            one row of ``theta_table`` is drawn at random per batch element
            on every call; otherwise row 0 is always used.
        kernel_init: Kernel initializer for every dense projection.
        bias_init: Bias initializer for the output projection (the only
            projection that has a bias).
        theta_init: Initializer for ``theta_table``.
        dtype: Computation dtype forwarded to the layer norm and the dense
            projections.
        param_dtype: Dtype of the parameters and of the initial carry.
        carry_init: Initializer used to build the initial memory matrix.
    """

    features: int
    output_features: int
    num_thetas: int = 128
    sample_theta: bool = True
    kernel_init: Initializer = default_kernel_init
    bias_init: Initializer = initializers.zeros_init()
    theta_init: Initializer = xavier_uniform()
    dtype: Dtype | None = None
    param_dtype: Dtype = jnp.float32
    carry_init: Initializer = initializers.zeros_init()

    @compact
    def __call__(self, carry: Array, inputs: Array) -> tuple[Array, Array]:
        """Advance the memory by one step.

        Args:
            carry: Current memory matrix, as produced by
                ``initialize_carry`` or a previous call.
            inputs: Input for this step; assumed to be
                ``(*batch, in_features)`` -- TODO confirm against callers.

        Returns:
            A ``(new_carry, output)`` tuple, where ``output`` has
            ``output_features`` entries on its last axis.
        """
        dense = partial(
            Dense,
            use_bias=False,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )

        inputs = LayerNorm(
            epsilon=1e-5,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            name="ln",
        )(inputs)

        v = dense(name="value", features=self.features)(inputs)
        k = jax.nn.relu(dense(name="key", features=self.features)(inputs))
        q = jax.nn.relu(dense(name="query", features=self.features)(inputs))
        v_c = dense(name="vc", features=self.features)(inputs)
        # Scalar gate in (0, 1) per batch element; broadcasts over the values.
        eta = sigmoid(dense(name="eta", features=1)(inputs))

        # Keys and queries are non-negative after the ReLU, so dividing by
        # their sum normalizes them; the 1e-5 guards against an all-zero row.
        k = k / (1e-5 + jnp.sum(k, axis=-1, keepdims=True))
        q = q / (1e-5 + jnp.sum(q, axis=-1, keepdims=True))

        # Outer product of the gated value with the key: the additive update.
        U = ((eta * v)[..., :, None]) * k[..., None, :]

        theta_table = self.param(
            "theta_table",
            self.theta_init,
            (self.num_thetas, self.features),
            self.param_dtype,
        )

        if self.sample_theta and self.has_rng("memory"):
            rng = self.make_rng("memory")
            batch_shape = v_c.shape[:-1]
            # One independently drawn table row per batch element.
            idx = random.randint(
                rng, batch_shape, 0, self.num_thetas, dtype=jnp.int32
            )
            theta_t = theta_table[idx]
            theta_t = jnp.broadcast_to(theta_t, v_c.shape)
        else:
            # No sampling requested or no "memory" RNG: fall back to row 0.
            theta_t = jnp.broadcast_to(theta_table[0], v_c.shape)

        # Calibration matrix; 1 + tanh(.) keeps every entry inside (0, 2).
        C = 1.0 + jnp.tanh(theta_t[..., :, None] * v_c[..., None, :])

        carry = carry * C + U
        # Read-out: matrix-vector product of the memory with the query.
        h = jnp.einsum("...ij,...j->...i", carry, q)

        # The output projection honours the module's dtype and initializer
        # settings like every other layer. It is the only biased layer, so
        # this is also the only place where ``bias_init`` takes effect.
        y = Dense(
            features=self.output_features,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            name="output",
        )(h)
        return carry, y

    @nowrap
    def initialize_carry(
        self, rng: PRNGKey, input_shape: tuple[int, ...]
    ) -> Array:
        """Build the initial memory matrix.

        Args:
            rng: PRNG key handed to ``carry_init``.
            input_shape: Shape of a single step's input; everything except
                the last axis is treated as batch dimensions.

        Returns:
            An array of shape ``(*batch, features, features)`` with dtype
            ``param_dtype``.
        """
        batch_dims = input_shape[:-1]
        mem_shape = batch_dims + (self.features, self.features)
        return self.carry_init(rng, mem_shape, self.param_dtype)

    @property
    def num_feature_axes(self) -> int:
        """Number of feature axes of the cell's inputs (always 1)."""
        return 1