Source code for memorax.networks.blocks.stack
from typing import Optional, Sequence

import flax.linen as nn
import jax

from memorax.utils.typing import Array, Carry
from .base import Block
[docs]
class Stack(nn.Module, Block):
    """Vertically stacks multiple heterogeneous blocks.

    Each block's output becomes the next block's input. Carry states are
    maintained per-block as a tuple, allowing different block types with
    different carry structures to be composed.

    Args:
        blocks: Sequence of blocks to stack. Each must implement the Block
            protocol.

    Example:
        stack = Stack(blocks=(
            Residual(module=PreNorm(module=SelfAttention(...))),
            Residual(module=PreNorm(module=FFN(...))),
            Residual(module=PreNorm(module=SelfAttention(...))),
            Residual(module=PreNorm(module=FFN(...))),
        ))
        carry, output = stack(inputs, mask, initial_carry)
    """

    blocks: Sequence[nn.Module]

    @nn.compact
    def __call__(
        self,
        inputs: Array,
        mask: Optional[Array] = None,
        initial_carry: Optional[tuple[Carry, ...]] = None,
        **kwargs,
    ) -> tuple[tuple[Carry, ...], Array]:
        """Run ``inputs`` through every block in order.

        Args:
            inputs: Input to the first block.
            mask: Optional mask, forwarded unchanged to every block.
            initial_carry: One carry per block, in block order. If ``None``,
                every block receives ``None`` as its initial carry.
            **kwargs: Extra keyword arguments forwarded to every block.

        Returns:
            A ``(carries, output)`` pair: ``carries`` is a tuple holding the
            carry returned by each block (in block order) and ``output`` is
            the last block's output.

        Raises:
            ValueError: If ``initial_carry`` does not hold exactly one entry
                per block.
        """
        num_blocks = len(self.blocks)
        if initial_carry is None:
            initial_carry = (None,) * num_blocks
        elif len(initial_carry) != num_blocks:
            # Without this check a short carry raises a bare IndexError and a
            # long one silently drops the trailing carries.
            raise ValueError(
                f"initial_carry has {len(initial_carry)} entries but Stack "
                f"has {num_blocks} blocks."
            )

        x = inputs
        carries = []
        for block, block_carry in zip(self.blocks, initial_carry):
            carry, x = block(x, mask=mask, initial_carry=block_carry, **kwargs)
            carries.append(carry)
        return tuple(carries), x

    @nn.nowrap
    def initialize_carry(self, key, input_shape) -> tuple[Optional[Carry], ...]:
        """Build the initial carry for every block.

        ``key`` is split into one independent subkey per block, so blocks
        that draw random initial carries do not all consume the same PRNG
        key. If ``key`` is ``None`` it is forwarded as ``None`` to every
        block. Blocks without an ``initialize_carry`` method contribute
        ``None``.

        Args:
            key: JAX PRNG key, or ``None``.
            input_shape: Forwarded unchanged to each block's
                ``initialize_carry``.

        Returns:
            A tuple with one entry per block, in block order, usable as the
            ``initial_carry`` argument of ``__call__``.
        """
        num_blocks = len(self.blocks)
        if key is None or num_blocks == 0:
            block_keys = (key,) * num_blocks
        else:
            # Reusing one key across blocks would give correlated samples.
            block_keys = jax.random.split(key, num_blocks)

        carries = []
        for block, block_key in zip(self.blocks, block_keys):
            init_fn = getattr(block, "initialize_carry", None)
            if init_fn is None:
                carries.append(None)
            else:
                carries.append(init_fn(block_key, input_shape))
        return tuple(carries)