from typing import Callable
import flax.linen as nn
import jax.numpy as jnp
from memorax.utils.typing import Array, Carry, Key
from .base import Block
[docs]
class FFN(nn.Module, Block):
"""Standard feed-forward network: Dense -> Activation -> Dense."""
features: int
expansion_factor: int = 4
activation: Callable = nn.gelu
dropout_rate: float = 0.0
kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()
[docs]
@nn.compact
def __call__(
self,
inputs: Array,
done: Array | None = None,
initial_carry: Carry | None = None,
**kwargs,
) -> tuple[Carry, Array]:
hidden_dim = self.features * self.expansion_factor
x = nn.Dense(
hidden_dim,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
)(inputs)
x = self.activation(x)
x = nn.Dropout(
rate=self.dropout_rate, deterministic=not self.has_rng("dropout")
)(x)
x = nn.Dense(
self.features,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
)(x)
return None, x
[docs]
@nn.nowrap
def initialize_carry(self, key: Key, input_shape: tuple) -> None:
return None
[docs]
class Projection(nn.Module, Block):
"""Single linear projection with optional activation."""
features: int
use_bias: bool = True
kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()
activation_fn: Callable | None = None
[docs]
@nn.compact
def __call__(
self,
inputs: Array,
done: Array | None = None,
initial_carry: Carry | None = None,
**kwargs,
) -> tuple[Carry, Array]:
x = nn.Dense(
self.features,
use_bias=self.use_bias,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
)(inputs)
if self.activation_fn is not None:
x = self.activation_fn(x)
return None, x
[docs]
@nn.nowrap
def initialize_carry(self, key: Key, input_shape: tuple) -> None:
return None
[docs]
class GLU(nn.Module, Block):
features: int
expansion_factor: int = 4
activation: Callable = nn.gelu
dropout_rate: float = 0.0
kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
use_bias: bool = False
[docs]
@nn.compact
def __call__(
self,
inputs: Array,
done: Array | None = None,
initial_carry: Carry | None = None,
**kwargs,
) -> tuple[Carry, Array]:
hidden_dim = self.features * self.expansion_factor
x = nn.Dense(
2 * hidden_dim,
kernel_init=self.kernel_init,
use_bias=self.use_bias,
name="up_proj",
)(inputs)
gate, value = jnp.split(x, 2, axis=-1)
x = self.activation(gate) * value
x = nn.Dropout(
rate=self.dropout_rate, deterministic=not self.has_rng("dropout")
)(x)
x = nn.Dense(
self.features,
kernel_init=self.kernel_init,
use_bias=self.use_bias,
name="down_proj",
)(x)
return None, x
[docs]
@nn.nowrap
def initialize_carry(self, key: Key, input_shape: tuple) -> None:
return None