Source code for memorax.networks.mlp

from typing import Callable, Optional, Sequence

import flax.linen as nn
import jax.numpy as jnp


class MLP(nn.Module):
    """Feed-forward stack of dense layers.

    Every layer is a ``nn.Dense`` followed by an optional normalizer and then
    the activation. The activation (and normalizer, if configured) is applied
    after *every* dense layer, including the last one.
    """

    # Output width of each dense layer; a bare int is treated as one layer.
    features: int | Sequence[int]
    # Element-wise non-linearity applied after each layer.
    activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    # Zero-argument factory; when set it is called once per layer and the
    # object it returns is applied to that layer's output.
    normalizer: Optional[Callable] = None
    # Initializers forwarded unchanged to every ``nn.Dense``.
    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()

    @nn.compact
    def __call__(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
        """Run ``x`` through each configured layer in order.

        Args:
            x: Input array passed to the first dense layer.
            **kwargs: Accepted but not used by this method.

        Returns:
            The output of the final layer after normalization (if any) and
            activation. With an empty ``features`` sequence the loop does not
            run and ``x`` is returned unchanged.
        """
        # Promote a single width to a one-element list so the loop is uniform.
        widths = [self.features] if isinstance(self.features, int) else self.features

        for width in widths:
            dense = nn.Dense(
                width,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init,
            )
            x = dense(x)

            if self.normalizer is not None:
                # A fresh normalizer is built for every layer.
                norm = self.normalizer()
                x = norm(x)

            x = self.activation(x)

        return x