Source code for memorax.networks.sequence_models.shm

from functools import partial
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax import struct
from flax.linen import initializers
from flax.linen.recurrent import RNNCellBase
from flax.typing import Dtype, Initializer
from jax import random

from memorax.utils.typing import Array

from memorax.networks.initializers import xavier_uniform


@struct.dataclass
class SHMConfig:
    """Hyperparameters and initializers read by ``SHMCell``.

    The initializer fields are declared with ``pytree_node=False`` so they
    are carried as static metadata rather than as pytree leaves.
    """

    # Width of the value / key / query / v_c projections. The carried memory
    # has trailing shape (features, features).
    features: int
    # Width of the cell's final output projection.
    output_features: int
    # Number of rows in the learned ``theta_table`` parameter.
    num_thetas: int = 128
    # If True and a "torso" RNG stream is available, a theta row is drawn at
    # random per batch element; otherwise row 0 of the table is used.
    sample_theta: bool = True
    # Kernel initializer forwarded to the cell's Dense layers.
    kernel_init: Initializer = struct.field(
        pytree_node=False,
        default=initializers.lecun_normal(),
    )
    # Bias initializer forwarded to the cell's Dense layers.
    bias_init: Initializer = struct.field(
        pytree_node=False,
        default=initializers.zeros_init(),
    )
    # Initializer for the (num_thetas, features) theta table.
    theta_init: Initializer = struct.field(
        pytree_node=False,
        default=xavier_uniform(),
    )
    # Forwarded as the ``dtype`` argument of the cell's layers.
    dtype: Dtype | None = None
    # dtype of the parameters and of the initial carry.
    param_dtype: Dtype = jnp.float32
    # Initializer used to build the initial memory in ``initialize_carry``.
    carry_init: Initializer = struct.field(
        pytree_node=False,
        default=initializers.zeros_init(),
    )


@struct.dataclass
class SHMCarry:
    """Recurrent state passed into and returned from ``SHMCell.__call__``."""

    # Matrix memory. ``SHMCell.initialize_carry`` creates it with shape
    # batch_dims + (features, features).
    memory: Array


class SHMCell(RNNCellBase):
    """Recurrent cell that carries a ``(features, features)`` matrix memory.

    Each step layer-normalises the input, projects it to value / key / query /
    ``v_c`` vectors plus a scalar gate, and updates the memory as

        M_t = M_{t-1} * (1 + tanh(theta_t ⊗ v_c)) + (eta * value) ⊗ key

    where ``⊗`` is an outer product and ``*`` is element-wise. The step output
    is a Dense projection of ``M_t @ query``.

    Attributes:
        config: Hyperparameters and initializers for the cell.
    """

    config: SHMConfig

    def setup(self):
        """Create the layer norm, input projections, theta table and output head."""
        cfg = self.config

        # Shared constructor for the bias-free input projections.
        dense = partial(
            nn.Dense,
            use_bias=False,
            dtype=cfg.dtype,
            param_dtype=cfg.param_dtype,
            kernel_init=cfg.kernel_init,
            bias_init=cfg.bias_init,
        )

        self.ln = nn.LayerNorm(
            epsilon=1e-5,
            dtype=cfg.dtype,
            param_dtype=cfg.param_dtype,
        )

        self.value = dense(features=cfg.features)
        self.key = dense(features=cfg.features)
        self.query = dense(features=cfg.features)
        self.v_c = dense(features=cfg.features)
        self.eta = dense(features=1)

        self.theta_table = self.param(
            "theta_table",
            cfg.theta_init,
            (cfg.num_thetas, cfg.features),
            cfg.param_dtype,
        )

        # The output head takes the same dtype / initializer settings as the
        # other layers, so a non-default config applies to the whole cell.
        self.output_projection = nn.Dense(
            cfg.output_features,
            dtype=cfg.dtype,
            param_dtype=cfg.param_dtype,
            kernel_init=cfg.kernel_init,
            bias_init=cfg.bias_init,
        )

    def __call__(self, carry: SHMCarry, inputs: Array) -> tuple[SHMCarry, Array]:
        """Advance the memory by one step.

        Args:
            carry: Previous state; ``carry.memory`` must broadcast against
                ``inputs.shape[:-1] + (features, features)``.
            inputs: Step input whose last axis is the feature axis.

        Returns:
            A ``(new_carry, output)`` pair, where ``output`` has
            ``config.output_features`` on its last axis.
        """
        cfg = self.config

        inputs = self.ln(inputs)
        value = self.value(inputs)
        key = jax.nn.relu(self.key(inputs))
        query = jax.nn.relu(self.query(inputs))
        v_c = self.v_c(inputs)
        eta_val = nn.sigmoid(self.eta(inputs))

        # ReLU outputs are non-negative, so dividing by their sum yields
        # weights that sum to (at most) one; the epsilon guards all-zero rows.
        key = key / (1e-5 + jnp.sum(key, axis=-1, keepdims=True))
        query = query / (1e-5 + jnp.sum(query, axis=-1, keepdims=True))

        # Additive write: gated value (outer) key.
        U = jnp.einsum("...i,...j->...ij", eta_val * value, key)

        if cfg.sample_theta and self.has_rng("torso"):
            # Draw one theta row per batch element from the learned table.
            rng = self.make_rng("torso")
            batch_shape = v_c.shape[:-1]
            idx = random.randint(
                rng, batch_shape, 0, cfg.num_thetas, dtype=jnp.int32
            )
            theta_t = self.theta_table[idx]
            theta_t = jnp.broadcast_to(theta_t, v_c.shape)
        else:
            # No sampling requested or no "torso" RNG stream: use row 0.
            theta_t = jnp.broadcast_to(self.theta_table[0], v_c.shape)

        # Multiplicative calibration; every entry lies in [0, 2] because tanh
        # is bounded by +/-1.
        C = 1.0 + jnp.tanh(jnp.einsum("...i,...j->...ij", theta_t, v_c))

        M = carry.memory * C + U

        # Read out by contracting the memory with the normalised query.
        h = jnp.einsum("...ij,...j->...i", M, query)
        y = self.output_projection(h)

        return SHMCarry(memory=M), y

    @nn.nowrap
    def initialize_carry(self, key: jax.Array, input_shape: tuple[int, ...]) -> SHMCarry:
        """Build the initial memory for inputs of shape ``input_shape``.

        Args:
            key: PRNG key handed to ``config.carry_init``.
            input_shape: Shape of a single step's input; the last axis is
                treated as the feature axis and dropped.

        Returns:
            An ``SHMCarry`` whose memory has shape
            ``input_shape[:-1] + (features, features)`` and dtype
            ``config.param_dtype``.
        """
        # tuple() lets list-valued shapes concatenate instead of raising.
        batch_dims = tuple(input_shape[:-1])
        mem_shape = batch_dims + (self.config.features, self.config.features)
        return SHMCarry(
            memory=self.config.carry_init(key, mem_shape, self.config.param_dtype)
        )

    @property
    def num_feature_axes(self) -> int:
        """Number of trailing input axes treated as features (always 1)."""
        return 1