Source code for memorax.networks.blocks.residual

from typing import Callable, Optional

import flax.linen as nn

from memorax.utils.typing import Array, Carry

from .base import Block


class Residual(nn.Module, Block):
    """Add a skip connection around a wrapped block: ``inputs + module(inputs)``.

    The wrapped ``module`` is called with the same arguments this block
    receives and must return a ``(carry, output)`` pair. The carry is passed
    through unchanged; only the output is summed with the inputs.
    """

    # The wrapped block whose output is added to the inputs.
    module: nn.Module

    @nn.compact
    def __call__(
        self,
        inputs: Array,
        mask: Optional[Array] = None,
        initial_carry: Optional[Carry] = None,
        **kwargs,
    ) -> tuple[Carry, Array]:
        """Run the wrapped module and add its output to ``inputs``.

        Args:
            inputs: Value fed to the wrapped module and used as the skip path.
            mask: Forwarded unchanged to the wrapped module.
            initial_carry: Forwarded unchanged to the wrapped module.
            **kwargs: Any extra keyword arguments, forwarded to the wrapped
                module.

        Returns:
            A ``(carry, output)`` tuple where ``carry`` is whatever the wrapped
            module returned and ``output`` is ``inputs`` plus the module's
            output.
        """
        next_carry, branch_out = self.module(
            inputs,
            mask=mask,
            initial_carry=initial_carry,
            **kwargs,
        )
        skip_sum = inputs + branch_out
        return next_carry, skip_sum

    @nn.nowrap
    def initialize_carry(self, key, input_shape):
        """Delegate carry initialisation to the wrapped module."""
        wrapped = self.module
        return wrapped.initialize_carry(key, input_shape)
class GatedResidual(nn.Module, Block):
    """Skip connection whose branch is scaled by a learned gate.

    Computes ``inputs + gate * module(inputs)``. The gate is produced from
    ``inputs`` by a dense layer named ``"gate"`` followed by the ``gate``
    activation.
    """

    # The wrapped block whose (gated) output is added to the inputs.
    module: nn.Module
    # Activation applied to the dense gate projection; sigmoid by default.
    gate: Callable = nn.sigmoid

    @nn.compact
    def __call__(
        self,
        inputs: Array,
        mask: Optional[Array] = None,
        initial_carry: Optional[Carry] = None,
        **kwargs,
    ) -> tuple[Carry, Array]:
        """Run the wrapped module and add its gated output to ``inputs``.

        Args:
            inputs: Value fed to the wrapped module, to the gate projection,
                and used as the skip path.
            mask: Forwarded unchanged to the wrapped module.
            initial_carry: Forwarded unchanged to the wrapped module.
            **kwargs: Any extra keyword arguments, forwarded to the wrapped
                module.

        Returns:
            A ``(carry, output)`` tuple where ``carry`` is whatever the wrapped
            module returned and ``output`` is ``inputs`` plus the element-wise
            product of the gate and the module's output.
        """
        next_carry, branch_out = self.module(
            inputs,
            mask=mask,
            initial_carry=initial_carry,
            **kwargs,
        )

        # The gate projection keeps the size of the last axis of ``inputs``.
        gate_projection = nn.Dense(inputs.shape[-1], name="gate")
        gate_values = self.gate(gate_projection(inputs))

        gated_branch = gate_values * branch_out
        return next_carry, inputs + gated_branch

    @nn.nowrap
    def initialize_carry(self, key, input_shape):
        """Delegate carry initialisation to the wrapped module."""
        wrapped = self.module
        return wrapped.initialize_carry(key, input_shape)