Source code for memorax.networks.layers.multi_head_layer_norm
import jax.numpy as jnp
from flax import linen as nn
from flax.typing import Dtype
from memorax.utils.typing import Array
[docs]
class MultiHeadLayerNorm(nn.Module):
eps: float = 1e-5
use_scale: bool = True
use_bias: bool = False
residual_weight: bool = True
dtype: Dtype | None = None
param_dtype: Dtype = jnp.float32
[docs]
@nn.compact
def __call__(self, x) -> Array:
B, NH, S, DH = x.shape
y = nn.vmap(
nn.LayerNorm,
variable_axes={"params": 0},
split_rngs={"params": True},
in_axes=1,
out_axes=1,
)(
epsilon=self.eps,
use_scale=False,
use_bias=False,
dtype=self.dtype,
param_dtype=self.param_dtype,
)(
x
)
if self.use_scale:
gamma = self.param(
"weight", nn.initializers.zeros_init(), (NH, DH), self.param_dtype
)
scale = (1.0 + gamma) if self.residual_weight else gamma
y = y * scale[None, :, None, :].astype(y.dtype)
if self.use_bias:
beta = self.param(
"bias", nn.initializers.zeros_init(), (NH, DH), self.param_dtype
)
y = y + beta[None, :, None, :].astype(y.dtype)
return y