Source code for memorax.networks.sequence_models.min_gru

from typing import Tuple

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax import struct
from flax.linen import initializers
from flax.typing import Dtype, Initializer

from memorax.utils.typing import Array, Carry

from .memoroid import MemoroidCellBase


@struct.dataclass
class MinGRUConfig:
    """Hyperparameters consumed by :class:`MinGRUCell`."""

    # Output width of the ``z`` and ``h`` Dense projections; also the size of
    # the last axis of the carry built by ``MinGRUCell.initialize_carry``.
    features: int
    # Kernel initializer forwarded to both Dense layers. Declared with
    # ``pytree_node=False`` so it is not a pytree node of this dataclass.
    kernel_init: Initializer = struct.field(
        pytree_node=False, default=initializers.lecun_normal()
    )
    # Bias initializer forwarded to both Dense layers (also
    # ``pytree_node=False``).
    # NOTE(review): MinGRUCell.setup builds the layers with ``use_bias=False``,
    # so this presumably has no effect -- confirm.
    bias_init: Initializer = struct.field(
        pytree_node=False, default=initializers.zeros_init()
    )
    # Passed as ``dtype`` to both Dense layers. When ``None``, the initial
    # carry falls back to ``param_dtype``.
    dtype: Dtype | None = None
    # Passed as ``param_dtype`` to both Dense layers.
    param_dtype: Dtype = jnp.float32


@struct.dataclass
class MinGRUCarry:
    """Log-space scan element produced and combined by :class:`MinGRUCell`."""

    # Logarithm of the state term; ``MinGRUCell.read`` exponentiates it to
    # produce the output.
    log_state: Array
    # Log-space decay: ``MinGRUCell.binary_operator`` adds the later element's
    # decay to the earlier element's ``log_state`` and sums the two decays.
    decay: Array


class MinGRUCell(MemoroidCellBase):
    """minGRU cell whose recurrence is expressed as log-space scan elements.

    Each input step is mapped to a ``MinGRUCarry`` holding the log of its
    gated candidate and the log of its decay factor. ``binary_operator``
    merges two such elements, and ``read`` maps a merged element back to
    linear space.
    """

    config: MinGRUConfig

    def setup(self):
        """Create the bias-free gate (``z``) and candidate (``h``) projections."""
        cfg = self.config
        # Both projections are configured identically apart from their name.
        dense_kwargs = dict(
            features=cfg.features,
            use_bias=False,
            dtype=cfg.dtype,
            param_dtype=cfg.param_dtype,
            kernel_init=cfg.kernel_init,
            bias_init=cfg.bias_init,
        )
        self.z = nn.Dense(name="z", **dense_kwargs)
        self.h = nn.Dense(name="h", **dense_kwargs)

    def __call__(self, x: Array, **kwargs) -> Carry:
        """Project ``x`` into per-step log-space scan elements.

        Args:
            x: Input array; it must have exactly three dimensions.
            **kwargs: Accepted but unused.

        Returns:
            A ``MinGRUCarry`` with ``log_state = log_z + log_h`` and
            ``decay = -softplus(z)``.
        """
        # The unpacked names are unused; the unpacking itself fails unless
        # ``x`` has exactly three dimensions.
        _batch, _steps, _ = x.shape

        gate_logits = self.z(x)
        candidate = self.h(x)

        # -softplus(-z) == log(sigmoid(z)).
        log_gate = -nn.softplus(-gate_logits)

        # Piecewise log of the candidate: log(h + 0.5) where h >= 0 and
        # log(sigmoid(h)) elsewhere. The relu keeps the log argument at or
        # above 0.5 even on the entries the `where` discards.
        non_negative_branch = jnp.log(nn.relu(candidate) + 0.5)
        negative_branch = -nn.softplus(-candidate)
        log_candidate = jnp.where(
            candidate >= 0, non_negative_branch, negative_branch
        )

        # -softplus(z) == log(1 - sigmoid(z)).
        log_decay = -nn.softplus(gate_logits)

        return MinGRUCarry(log_state=log_gate + log_candidate, decay=log_decay)

    def binary_operator(self, a: Carry, b: Carry) -> Carry:
        """Combine two scan elements in log space.

        ``a`` is treated as the earlier element: its ``log_state`` is shifted
        by ``b.decay`` before being log-sum-exp'd with ``b.log_state``. The
        decays of the two elements are added.
        """
        decayed_earlier_state = b.decay + a.log_state
        merged_log_state = jnp.logaddexp(decayed_earlier_state, b.log_state)
        merged_decay = a.decay + b.decay
        return MinGRUCarry(log_state=merged_log_state, decay=merged_decay)

    def read(self, h: Carry, x: Array, **kwargs) -> Array:
        """Return ``exp(h.log_state)``; ``x`` and ``kwargs`` are unused."""
        return jnp.exp(h.log_state)

    def initialize_carry(self, key: jax.Array, input_shape: Tuple[int, ...]) -> Carry:
        """Build the initial carry for a given input shape.

        Args:
            key: Unused.
            input_shape: Shape whose last entry is dropped; the remaining
                leading entries become the leading carry dimensions.

        Returns:
            A ``MinGRUCarry`` of shape ``(*leading, 1, features)`` with
            ``log_state`` filled with ``-inf`` and ``decay`` filled with zeros.
        """
        *leading_dims, _ = input_shape
        carry_shape = (*leading_dims, 1, self.config.features)
        # Use the computation dtype when set, otherwise the parameter dtype.
        carry_dtype = self.config.dtype or self.config.param_dtype
        return MinGRUCarry(
            log_state=jnp.full(carry_shape, -jnp.inf, dtype=carry_dtype),
            decay=jnp.zeros(carry_shape, dtype=carry_dtype),
        )