Source code for memorax.networks.blocks.ffn
from typing import Callable, Optional
import flax.linen as nn
import jax.numpy as jnp
from memorax.utils.typing import Array, Carry
from .base import Block
[docs]
class FFN(nn.Module, Block):
    """Stateless two-layer feed-forward block: Dense -> activation -> dropout -> Dense.

    The first projection widens the last axis to
    ``features * expansion_factor`` units; the second projects it back to
    ``features`` units.

    Attributes:
        features: Number of output units of the final projection.
        expansion_factor: Multiplier applied to ``features`` to obtain the
            hidden width.
        activation: Nonlinearity applied to the hidden projection.
        dropout_rate: Rate handed to ``nn.Dropout`` on the hidden activations.
        kernel_init: Kernel initializer used by both Dense layers.
        bias_init: Bias initializer used by both Dense layers.
    """

    features: int
    expansion_factor: int = 4
    activation: Callable = nn.gelu
    dropout_rate: float = 0.0
    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()

    @nn.compact
    def __call__(
        self,
        inputs: Array,
        mask: Optional[Array] = None,
        initial_carry: Optional[Carry] = None,
        **kwargs,
    ) -> tuple[Carry, Array]:
        """Apply the feed-forward transformation to ``inputs``.

        Args:
            inputs: Array to transform; the Dense layers act on its last axis.
            mask: Accepted but not used by this block.
            initial_carry: Accepted but not used by this block.
            **kwargs: Accepted and ignored.

        Returns:
            A ``(carry, outputs)`` pair where ``carry`` is always ``None``.
        """
        # Both projections share the same initializers.
        dense_kwargs = dict(kernel_init=self.kernel_init, bias_init=self.bias_init)

        # NOTE: submodules are created in the order Dense, Dropout, Dense so
        # that Flax's automatic submodule names stay stable.
        up_proj = nn.Dense(self.features * self.expansion_factor, **dense_kwargs)
        hidden = self.activation(up_proj(inputs))

        # Dropout only fires when the caller provided a "dropout" RNG stream.
        deterministic = not self.has_rng("dropout")
        dropout = nn.Dropout(rate=self.dropout_rate, deterministic=deterministic)
        hidden = dropout(hidden)

        down_proj = nn.Dense(self.features, **dense_kwargs)
        outputs = down_proj(hidden)
        return None, outputs

    @nn.nowrap
    def initialize_carry(self, key, input_shape):
        """Return the initial carry, which is always ``None`` for this block.

        ``key`` and ``input_shape`` are accepted but not used.
        """
        return None
class GatedFFN(nn.Module, Block):
    """Gated (GLU-style) feed-forward block.

    Pipeline: Dense -> split into (gate, value) -> ``activation(gate) * value``
    -> dropout -> Dense. The default gate activation is ``nn.gelu``; a
    different one can be supplied through ``activation``.

    Attributes:
        features: Number of output units of the final projection.
        expansion_factor: Multiplier applied to ``features`` to obtain the
            width of each of the gate and value halves.
        activation: Nonlinearity applied to the gate half only.
        dropout_rate: Rate handed to ``nn.Dropout`` on the gated activations.
        kernel_init: Kernel initializer used by both Dense layers.
        use_bias: Whether the Dense layers add a bias term.
    """

    features: int
    expansion_factor: int = 4
    activation: Callable = nn.gelu
    dropout_rate: float = 0.0
    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    use_bias: bool = False

    @nn.compact
    def __call__(
        self,
        inputs: Array,
        mask: Optional[Array] = None,
        initial_carry: Optional[Carry] = None,
        **kwargs,
    ) -> tuple[Carry, Array]:
        """Apply the gated feed-forward transformation to ``inputs``.

        Args:
            inputs: Array to transform; the Dense layers act on its last axis.
            mask: Accepted but not used by this block.
            initial_carry: Accepted but not used by this block.
            **kwargs: Accepted and ignored.

        Returns:
            A ``(carry, outputs)`` pair where ``carry`` is always ``None``.
        """
        # Both projections share the same kernel initializer and bias setting.
        dense_kwargs = dict(kernel_init=self.kernel_init, use_bias=self.use_bias)
        width = self.features * self.expansion_factor

        # One fused projection yields gate and value side by side on the last axis.
        up_proj = nn.Dense(2 * width, name="up_proj", **dense_kwargs)
        gate, value = jnp.split(up_proj(inputs), 2, axis=-1)
        gated = self.activation(gate) * value

        # Dropout only fires when the caller provided a "dropout" RNG stream.
        deterministic = not self.has_rng("dropout")
        dropout = nn.Dropout(rate=self.dropout_rate, deterministic=deterministic)
        gated = dropout(gated)

        down_proj = nn.Dense(self.features, name="down_proj", **dense_kwargs)
        outputs = down_proj(gated)
        return None, outputs

    @nn.nowrap
    def initialize_carry(self, key, input_shape):
        """Return the initial carry, which is always ``None`` for this block.

        ``key`` and ``input_shape`` are accepted but not used.
        """
        return None