Source code for memorax.networks.blocks.ffn

from typing import Callable

import flax.linen as nn
import jax.numpy as jnp

from memorax.utils.typing import Array, Carry, Key

from .base import Block


class FFN(nn.Module, Block):
    """Two-layer feed-forward block: expand, activate, drop out, project back.

    Attributes:
        features: Width of the final dense projection.
        expansion_factor: Multiplier applied to ``features`` to obtain the
            width of the hidden layer.
        activation: Nonlinearity applied to the output of the first dense layer.
        dropout_rate: Rate passed to ``nn.Dropout`` on the hidden activations.
        kernel_init: Kernel initializer used by both dense layers.
        bias_init: Bias initializer used by both dense layers.
    """

    features: int
    expansion_factor: int = 4
    activation: Callable = nn.gelu
    dropout_rate: float = 0.0
    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()

    @nn.compact
    def __call__(
        self,
        inputs: Array,
        done: Array | None = None,
        initial_carry: Carry | None = None,
        **kwargs,
    ) -> tuple[Carry, Array]:
        """Apply the feed-forward stack to ``inputs``.

        ``done``, ``initial_carry`` and any extra keyword arguments are
        accepted but not used by this block.

        Returns:
            A ``(carry, outputs)`` pair whose carry is always ``None``.
        """
        # Both dense layers share the same initializers.
        init_options = {
            "kernel_init": self.kernel_init,
            "bias_init": self.bias_init,
        }

        expand = nn.Dense(self.features * self.expansion_factor, **init_options)
        hidden = self.activation(expand(inputs))

        # Dropout is only active when a "dropout" RNG stream was supplied.
        dropout_active = self.has_rng("dropout")
        drop = nn.Dropout(rate=self.dropout_rate, deterministic=not dropout_active)
        hidden = drop(hidden)

        contract = nn.Dense(self.features, **init_options)
        return None, contract(hidden)

    @nn.nowrap
    def initialize_carry(self, key: Key, input_shape: tuple) -> None:
        """Return ``None``; ``key`` and ``input_shape`` are ignored."""
        return None
class Projection(nn.Module, Block):
    """One dense layer, optionally followed by an activation function.

    Attributes:
        features: Output width of the dense layer.
        use_bias: Forwarded to ``nn.Dense`` as ``use_bias``.
        kernel_init: Kernel initializer for the dense layer.
        bias_init: Bias initializer for the dense layer.
        activation_fn: Optional callable applied to the dense output; when
            ``None`` the dense output is returned unchanged.
    """

    features: int
    use_bias: bool = True
    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()
    activation_fn: Callable | None = None

    @nn.compact
    def __call__(
        self,
        inputs: Array,
        done: Array | None = None,
        initial_carry: Carry | None = None,
        **kwargs,
    ) -> tuple[Carry, Array]:
        """Project ``inputs`` and apply ``activation_fn`` if one is set.

        ``done``, ``initial_carry`` and any extra keyword arguments are
        accepted but not used by this block.

        Returns:
            A ``(carry, outputs)`` pair whose carry is always ``None``.
        """
        linear = nn.Dense(
            self.features,
            use_bias=self.use_bias,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )
        projected = linear(inputs)

        # Purely linear projection when no activation was configured.
        if self.activation_fn is None:
            return None, projected
        return None, self.activation_fn(projected)

    @nn.nowrap
    def initialize_carry(self, key: Key, input_shape: tuple) -> None:
        """Return ``None``; ``key`` and ``input_shape`` are ignored."""
        return None
class GLU(nn.Module, Block):
    """Gated feed-forward block.

    A single up-projection produces a tensor that is split in half along the
    last axis into a gate and a value. The activated gate multiplies the
    value, and the product goes through dropout and a down-projection.

    Attributes:
        features: Width of the down-projection output.
        expansion_factor: Multiplier applied to ``features`` to obtain the
            width of each of the gate and value halves.
        activation: Nonlinearity applied to the gate half.
        dropout_rate: Rate passed to ``nn.Dropout`` on the gated activations.
        kernel_init: Kernel initializer used by both dense layers.
        use_bias: Forwarded to both dense layers as ``use_bias``.
    """

    features: int
    expansion_factor: int = 4
    activation: Callable = nn.gelu
    dropout_rate: float = 0.0
    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    use_bias: bool = False

    @nn.compact
    def __call__(
        self,
        inputs: Array,
        done: Array | None = None,
        initial_carry: Carry | None = None,
        **kwargs,
    ) -> tuple[Carry, Array]:
        """Apply the gated feed-forward stack to ``inputs``.

        ``done``, ``initial_carry`` and any extra keyword arguments are
        accepted but not used by this block.

        Returns:
            A ``(carry, outputs)`` pair whose carry is always ``None``.
        """
        half_width = self.features * self.expansion_factor

        # One matmul yields both halves; they are separated right after.
        up_layer = nn.Dense(
            2 * half_width,
            kernel_init=self.kernel_init,
            use_bias=self.use_bias,
            name="up_proj",
        )
        gate, value = jnp.split(up_layer(inputs), 2, axis=-1)
        gated = self.activation(gate) * value

        # Dropout is only active when a "dropout" RNG stream was supplied.
        dropout_active = self.has_rng("dropout")
        drop = nn.Dropout(rate=self.dropout_rate, deterministic=not dropout_active)
        gated = drop(gated)

        down_layer = nn.Dense(
            self.features,
            kernel_init=self.kernel_init,
            use_bias=self.use_bias,
            name="down_proj",
        )
        return None, down_layer(gated)

    @nn.nowrap
    def initialize_carry(self, key: Key, input_shape: tuple) -> None:
        """Return ``None``; ``key`` and ``input_shape`` are ignored."""
        return None