# Source code for memorax.networks.sequence_models.min_gru

from functools import partial
from typing import Any, Callable, Optional, Tuple

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.linen import initializers
from memorax.utils.typing import Array, Carry

from .memoroid import MemoroidCellBase

# Loose alias for a PRNG key argument; declared as Any, so it is not type-checked.
PRNGKey = Any
# Array shape: a tuple of ints of any length.
Shape = Tuple[int, ...]
# Loose alias for a dtype specifier; declared as Any, so it is not type-checked.
Dtype = Any
# Signature of a parameter initializer: (key, shape, dtype) -> Array.
Initializer = Callable[[PRNGKey, Shape, Dtype], Array]


class MinGRUCell(MemoroidCellBase):
    """Minimal GRU algebra using log-space computation.

    Operates in log-space to avoid numerical overflow for long sequences.

    Every scan element is a pair ``(log_state, cumulative_decay)``. For a left
    operand ``i`` and a right operand ``j`` the combine rule is::

        (logaddexp(decay_j + log_state_i, log_state_j), decay_i + decay_j)

    Attributes:
        features: Output width of both dense projections and size of the last
            axis of the carry.
        kernel_init: Initializer for the projection kernels.
        bias_init: Forwarded to ``nn.Dense``. The projections are created with
            ``use_bias=False``, so no bias parameter is requested here.
        dtype: Computation dtype forwarded to ``nn.Dense``. When ``None`` the
            initial carry falls back to ``param_dtype``.
        param_dtype: Parameter dtype forwarded to ``nn.Dense``.
    """

    features: int
    kernel_init: Initializer = initializers.lecun_normal()
    bias_init: Initializer = initializers.zeros_init()
    dtype: Optional[Dtype] = None
    param_dtype: Dtype = jnp.float32

    @nn.compact
    def __call__(self, inputs: Array, **kwargs) -> Carry:
        """Project ``inputs`` into log-space scan elements.

        Args:
            inputs: Rank-3 array. Axis 0 is treated as the batch axis by
                ``initialize_carry``; presumably the layout is
                ``(batch, time, input_features)`` -- TODO confirm with callers.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            A ``(log_state, decay)`` pair, each with ``features`` as the size
            of its last axis.

        Raises:
            ValueError: If ``inputs`` does not have exactly three axes.
        """
        # Explicit rank check; previously enforced implicitly by unpacking
        # ``inputs.shape`` into three (otherwise unused) names.
        if len(inputs.shape) != 3:
            raise ValueError(
                "MinGRUCell expects rank-3 inputs, got shape "
                f"{tuple(inputs.shape)}"
            )

        dense = partial(
            nn.Dense,
            features=self.features,
            use_bias=False,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )

        # Submodule names are part of the parameter tree; keep them stable.
        z = dense(name="z")(inputs)
        h_tilde = dense(name="h")(inputs)

        # -softplus(-z) == log(sigmoid(z)): log of the update gate.
        log_z = -nn.softplus(-z)
        # Log of the candidate activation: log(h + 0.5) for h >= 0 and
        # log(sigmoid(h)) otherwise. ``jnp.where`` evaluates both branches, so
        # ``relu`` keeps the argument of ``log`` positive where h < 0.
        log_h_tilde = jnp.where(
            h_tilde >= 0,
            jnp.log(nn.relu(h_tilde) + 0.5),
            -nn.softplus(-h_tilde),
        )
        log_state = log_z + log_h_tilde
        # -softplus(z) == log(1 - sigmoid(z)): log of the retention factor.
        decay = -nn.softplus(z)
        return (log_state, decay)

    def binary_operator(self, a: Carry, b: Carry) -> Carry:
        """Log-space combine: logaddexp for numerically stable addition.

        Args:
            a: Left element ``(log_state_i, decay_i)``.
            b: Right element ``(log_state_j, decay_j)``.

        Returns:
            ``(logaddexp(decay_j + log_state_i, log_state_j),
            decay_i + decay_j)``.
        """
        log_state_i, decay_i = a
        log_state_j, decay_j = b
        return (
            # Left state scaled by the right decay, added to the right state;
            # both the product and the sum are carried out on logarithms.
            jnp.logaddexp(decay_j + log_state_i, log_state_j),
            # Decays multiply, so their logarithms add.
            decay_i + decay_j,
        )

    def read(self, h: Carry, x: Array, **kwargs) -> Array:
        """Return the state in linear space.

        Args:
            h: Carry ``(log_state, decay)``; only ``log_state`` is used.
            x: Unused.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            ``exp(log_state)``.
        """
        log_state, _ = h
        return jnp.exp(log_state)

    def initialize_carry(self, key: jax.Array, input_shape: Tuple[int, ...]) -> Carry:
        """Build the identity element of ``binary_operator``.

        Args:
            key: Unused; the initial carry is deterministic.
            input_shape: Shape whose first entry is the batch size.

        Returns:
            ``(log_state, decay)``, each of shape
            ``(batch_size, 1, self.features)``: ``log_state`` is filled with
            ``-inf`` and ``decay`` with zeros.
        """
        batch_size, *_ = input_shape
        carry_shape = (batch_size, 1, self.features)
        # Explicit ``is not None`` instead of ``or``: truthiness is the wrong
        # test for an Optional dtype, because a dtype object that evaluates as
        # falsy would be silently replaced by ``param_dtype``.
        carry_dtype = self.dtype if self.dtype is not None else self.param_dtype

        # Identity for logaddexp is -inf, identity for + is 0
        log_state = jnp.full(carry_shape, -jnp.inf, dtype=carry_dtype)
        decay = jnp.zeros(carry_shape, dtype=carry_dtype)
        return (log_state, decay)