Source code for memorax.networks.blocks.normalization

from typing import Callable, Optional

import flax.linen as nn

from memorax.utils.typing import Array, Carry

from .base import Block


[docs] class PreNorm(nn.Module, Block): """Applies normalization before the module: output = module(norm(x)). Args: module: The module to wrap. norm: Normalization class (default: nn.LayerNorm). norm_kwargs: Additional kwargs passed to the norm constructor. """ module: nn.Module norm: Callable = nn.LayerNorm
[docs] @nn.compact def __call__( self, inputs: Array, mask: Optional[Array] = None, initial_carry: Optional[Carry] = None, **kwargs, ) -> tuple[Carry, Array]: x = self.norm()(inputs) return self.module(x, mask=mask, initial_carry=initial_carry, **kwargs)
[docs] @nn.nowrap def initialize_carry(self, key, input_shape): return self.module.initialize_carry(key, input_shape)
[docs] class PostNorm(nn.Module, Block): """Applies normalization after the module: output = norm(module(x)). Args: module: The module to wrap. norm: Normalization class (default: nn.LayerNorm). norm_kwargs: Additional kwargs passed to the norm constructor. """ module: nn.Module norm: Callable = nn.LayerNorm
[docs] @nn.compact def __call__( self, inputs: Array, mask: Optional[Array] = None, initial_carry: Optional[Carry] = None, **kwargs, ) -> tuple[Carry, Array]: carry, output = self.module( inputs, mask=mask, initial_carry=initial_carry, **kwargs ) return carry, self.norm()(output)
[docs] @nn.nowrap def initialize_carry(self, key, input_shape): return self.module.initialize_carry(key, input_shape)