Source code for memorax.networks.layers.block_diagonal_dense

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.typing import Dtype, Initializer

from memorax.networks.initializers import small
from memorax.utils.typing import Array


class BlockDiagonalDense(nn.Module):
    """Linear layer whose weight matrix is block-diagonal, one block per head.

    The last axis of the input is split into ``num_heads`` equal chunks and
    each chunk is multiplied by its own square ``(block_size, block_size)``
    matrix, where ``block_size = input_features // num_heads``. The chunks are
    then concatenated again, so the output has the same shape as the input.

    Attributes:
        features: Length of the bias vector. It is only read when ``use_bias``
            is true; the kernel shape is derived from the input's last axis.
        num_heads: Number of diagonal blocks. Must be positive and divide the
            size of the input's last axis.
        use_bias: Whether to add a learned bias to the output.
        kernel_init: Initializer for the kernel. When ``None``, the project
            initializer ``small(block_size)`` is used.
        bias_init: Initializer for the bias.
        dtype: Computation dtype. When ``None``, it is inferred from the input
            and the parameters using the usual promotion rules.
        param_dtype: Dtype passed to the parameter initializers.
    """

    features: int
    num_heads: int
    use_bias: bool = True
    kernel_init: Initializer | None = None
    bias_init: Initializer = nn.initializers.zeros_init()
    dtype: Dtype | None = None
    param_dtype: Dtype = jnp.float32

    @nn.compact
    def __call__(self, x: Array) -> Array:
        """Apply the block-diagonal transformation along the last axis of ``x``.

        Args:
            x: Input array of shape ``(*batch, input_features)``.

        Returns:
            An array with the same shape as ``x``.

        Raises:
            ValueError: If ``num_heads`` is not positive or does not divide
                the size of the last axis of ``x``.
        """
        *batch, features = x.shape

        # Fail early with a readable message; otherwise the mismatch only
        # surfaces later as an opaque reshape error after the kernel exists.
        if self.num_heads <= 0 or features % self.num_heads != 0:
            raise ValueError(
                f"Input feature size {features} must be divisible by a positive "
                f"num_heads, got num_heads={self.num_heads}."
            )
        block_size = features // self.num_heads

        kernel_init = self.kernel_init
        if kernel_init is None:
            kernel_init = small(block_size)
        kernel = self.param(
            "kernel",
            kernel_init,
            (self.num_heads, block_size, block_size),
            self.param_dtype,
        )

        bias = None
        if self.use_bias:
            bias = self.param(
                "bias",
                self.bias_init,
                (self.features,),
                self.param_dtype,
            )

        # Honour ``self.dtype``: cast input and parameters to one computation
        # dtype. With ``dtype=None`` this matches ordinary type promotion.
        x, kernel, bias = nn.dtypes.promote_dtype(x, kernel, bias, dtype=self.dtype)

        # (*batch, features) -> (*batch, heads, block_size)
        x = x.reshape(*batch, self.num_heads, -1)
        # Per head h: out[..., h, o] = sum_d x[..., h, d] * kernel[h, o, d]
        x = jnp.einsum("...hd,hod->...ho", x, kernel)
        # Merge the heads back into a single feature axis.
        x = x.reshape(*batch, -1)

        if bias is not None:
            # Expanding the bias to the output shape explicitly makes an
            # incompatible ``features`` setting raise here.
            bias = jnp.broadcast_to(bias, x.shape)
            x = x + bias
        return x