memorax.networks.sequence_models

Sequence models for temporal processing.

RNN Models

RNN - Wrapper for Flax RNN cells (LSTM, GRU, etc.).

sLSTMCell - Scalar LSTM cell with enhanced gating and feature normalization.
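In the xLSTM paper, the sLSTM replaces the sigmoid input gate with an exponential one and tracks a normalizer state n_t. A log-space stabilizer m_t keeps the exponentials from overflowing. The sketch below shows one step for a single scalar unit under that formulation. The gate pre-activations are passed in directly, whereas a real cell computes them from x_t and h_{t-1}. The function names are hypothetical and are not part of `sLSTMCell`.

```python
import math

# Illustrative sketch of one stabilized sLSTM step; hypothetical names, not the memorax API.

def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    return -(max(-x, 0.0) + math.log1p(math.exp(-abs(x))))

def slstm_step(state, i_pre, f_pre, z_pre, o_pre):
    """state = (c, n, m): cell, normalizer, and log-space stabilizer.

    i_pre, f_pre, z_pre, o_pre are gate pre-activations (assumed given here).
    """
    c, n, m = state
    log_f = log_sigmoid(f_pre)              # sigmoid forget gate, in log space
    log_i = i_pre                           # exponential input gate: i = exp(i_pre)
    m_new = max(log_f + m, log_i)           # stabilizer keeps both exponents <= 0
    i = math.exp(log_i - m_new)
    f = math.exp(log_f + m - m_new)
    z = math.tanh(z_pre)
    o = math.exp(log_sigmoid(o_pre))
    c_new = f * c + i * z
    n_new = f * n + i
    h = o * c_new / n_new                   # the shared exp(-m) scale cancels in c/n
    return (c_new, n_new, m_new), h
```

The stabilizer rescales c_t and n_t by the same factor exp(-m_t). The ratio c_t / n_t, and therefore h_t, is unchanged from the unstabilized recurrence; the only effect is that overflow is avoided.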

SHMCell - Stable Hadamard Memory cell.

Memoroid Models

Memoroid - Wrapper for parallel-scannable sequence models using associative scan.
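A recurrence is parallel-scannable when each step can be written as an element and the elements combined with an associative operator. The minimal sketch below assumes the simplest case, h_t = a_t * h_{t-1} + b_t, which is not necessarily the operator any particular memoroid cell uses. Each step is the affine map h -> a*h + b, and composing two such maps is associative. Because of this, a Hillis-Steele scan can evaluate all prefixes in about log2(T) rounds. In JAX the scan would be `jax.lax.associative_scan`; this pure-Python version only illustrates the structure, and the function names are hypothetical.

```python
# Illustrative sketch; hypothetical names, not the memorax API.

def combine(left, right):
    """Compose two affine maps h -> a*h + b, applying `left` first."""
    a1, b1 = left
    a2, b2 = right
    return a1 * a2, a2 * b1 + b2

def inclusive_scan(op, xs):
    """Hillis-Steele inclusive scan: each round is independent across positions."""
    out = list(xs)
    step = 1
    while step < len(out):
        # position t combines with the partial result `step` positions earlier
        out = [out[t] if t < step else op(out[t - step], out[t]) for t in range(len(out))]
        step *= 2
    return out

def linear_recurrence(a, b, h0=0.0):
    """All h_t for h_t = a_t * h_{t-1} + b_t, computed through the associative scan."""
    prefixes = inclusive_scan(combine, list(zip(a, b)))
    return [A * h0 + B for A, B in prefixes]
```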

MemoroidCellBase - Base class for memoroid cells.

Mamba2Cell - Selective State Space Model cell (Mamba-2).
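In Mamba-2, each head uses a single scalar decay A shared across its whole state. The step size Δ_t and the projections B_t and C_t, by contrast, depend on the input; that input dependence is what makes the model selective. The sequential sketch below covers one head under those assumptions. It omits the convolution, gating, and skip connection of a full block, and it is not the `Mamba2Cell` code path.

```python
import numpy as np

# Illustrative sketch of one selective-SSM head; hypothetical names, not the memorax API.

def softplus(x):
    return np.logaddexp(0.0, x)

def selective_ssm(x, dt_raw, B, C, A=-1.0):
    """x: (T, P) inputs; dt_raw: (T,) pre-softplus step sizes;
    B, C: (T, N) input-dependent projections; A: negative scalar decay."""
    T, P = x.shape
    N = B.shape[1]
    h = np.zeros((N, P))
    ys = np.empty((T, P))
    for t in range(T):
        dt = softplus(dt_raw[t])
        h = np.exp(dt * A) * h + dt * np.outer(B[t], x[t])  # decay, then write
        ys[t] = C[t] @ h                                     # read: (N,) @ (N, P) -> (P,)
    return ys
```

The same outputs can also be written as one masked matrix product, y = (L ∘ C Bᵀ) diag(Δ) x, where L[t, s] = exp(A Σ_{r=s+1..t} Δ_r) for s ≤ t and 0 otherwise. This equivalence lets Mamba-2 replace the sequential loop with matrix multiplies.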

Mamba3Cell - State Space Model with trapezoidal discretization and complex state (Mamba-3).
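For a continuous system dh/dt = A h + B x, zero-order hold assumes the input stays constant over each step. A trapezoidal rule instead averages the integrand at both ends of the step, so each update depends on x_{t-1} as well as x_t. The sketch below applies the textbook trapezoidal rule with a complex diagonal A and a fixed step size. Mamba-3's actual parameterization, including its input-dependent step sizes and how the two endpoints are weighted, is not reproduced here and may differ.

```python
import numpy as np

# Generic trapezoidal discretization; an assumption, not Mamba-3's exact parameterization.

def trapezoidal_ssm(x, A, B, dt):
    """h_t = e^{dt A} h_{t-1} + (dt/2) (e^{dt A} B x_{t-1} + B x_t)

    x: (T,) real scalar inputs; A, B: (N,) complex diagonal parameters.
    """
    decay = np.exp(dt * A)
    h = np.zeros_like(A, dtype=complex)
    prev_u = np.zeros_like(h)          # B x_{-1} = 0 before the sequence starts
    hs = []
    for xt in x:
        u = B * xt
        h = decay * h + 0.5 * dt * (decay * prev_u + u)
        prev_u = u
        hs.append(h)
    return np.array(hs)
```

For a constant input x = 1, the fixed point is h* = -(Δ/2) coth(ΔA/2) B. This matches the exact continuous value -B / A up to an O(Δ²) error.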

S5Cell - Simplified Structured State Space cell.
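S5 discretizes a diagonalized continuous-time system with zero-order hold, then runs a single multi-input, multi-output linear recurrence over the sequence. The sketch below shows both pieces under that description. The function names are hypothetical, the feedthrough term D u_k is omitted, and a sequential loop stands in for the parallel scan.

```python
import numpy as np

# Illustrative sketch; hypothetical names, not the memorax API.

def discretize_zoh(Lambda, B, dt):
    """Lambda: (N,) complex eigenvalues (Re < 0); B: (N, H); dt: scalar or (N,)."""
    Lambda_bar = np.exp(Lambda * dt)
    B_bar = ((Lambda_bar - 1.0) / Lambda)[:, None] * B
    return Lambda_bar, B_bar

def s5_scan(Lambda_bar, B_bar, C, u):
    """x_k = Lambda_bar * x_{k-1} + B_bar u_k and y_k = Re(C x_k); u: (T, H)."""
    x = np.zeros(Lambda_bar.shape, dtype=complex)
    ys = []
    for u_k in u:
        x = Lambda_bar * x + B_bar @ u_k
        ys.append((C @ x).real)
    return np.array(ys)
```

Under zero-order hold, a constant input u drives the state to exactly -B u / Λ, the same fixed point as the continuous system. This makes a convenient sanity check.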

LRUCell - Linear Recurrent Unit cell.
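The LRU parameterizes each complex eigenvalue as λ = exp(-exp(ν) + i·exp(θ)), which guarantees |λ| < 1 for any real ν. It also scales the input by γ = sqrt(1 - |λ|²), so that for white-noise input the state variance does not blow up as |λ| approaches 1. The sketch below follows that formulation. The names and shapes are assumptions, and the nonlinear projection that usually follows the recurrence is omitted.

```python
import numpy as np

# Illustrative sketch; hypothetical names, not the memorax API.

def lru_params(nu_log, theta_log):
    """Stable complex eigenvalues and the input normalization gamma."""
    lam = np.exp(-np.exp(nu_log) + 1j * np.exp(theta_log))  # |lam| = exp(-exp(nu_log)) < 1
    gamma = np.sqrt(1.0 - np.abs(lam) ** 2)
    return lam, gamma

def lru_forward(x, nu_log, theta_log, B, C, D):
    """x: (T, H); B: (N, H) complex; C: (H, N) complex; D: (H,)."""
    lam, gamma = lru_params(nu_log, theta_log)
    h = np.zeros(lam.shape, dtype=complex)
    ys = []
    for x_t in x:
        h = lam * h + gamma * (B @ x_t)        # diagonal complex recurrence
        ys.append((C @ h).real + D * x_t)      # real readout plus skip term
    return np.array(ys)
```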

MinGRUCell - Minimal GRU cell (log-space).
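minGRU removes the hidden-state dependence from its gate and candidate: z_t = σ(k_t), and the candidate comes from x_t alone. As a result, h_t = (1 - z_t) h_{t-1} + z_t h̃_t is a linear recurrence. The log-space variant in "Were RNNs All We Needed?" keeps the candidate positive with g(x) = x + 0.5 for x ≥ 0 and σ(x) otherwise. It then evaluates the unrolled form h_t = a*_t (h_0 + Σ_{s≤t} b_s / a*_s) with log-sum-exp, where a*_t = Π_{r≤t} (1 - z_r) and b_s = z_s g(u_s). The scalar sketch below assumes that formulation. Its running log-sum-exp loop stands in for the parallel cumulative log-sum-exp a real implementation would use, and the names are hypothetical.

```python
import math

# Illustrative scalar sketch; hypothetical names, not the memorax API.

def softplus(x):
    """Numerically stable log(1 + e^x)."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def log_g(x):
    """log g(x) with g(x) = x + 0.5 for x >= 0 and sigmoid(x) otherwise (always > 0)."""
    return math.log(x + 0.5) if x >= 0 else -softplus(-x)

def min_gru_log_space(k, u, h0):
    """k: gate pre-activations; u: candidate pre-activations; h0 > 0. Returns [h_1..h_T]."""
    log_a_star = 0.0      # log prod_{r<=t} (1 - z_r)
    acc = math.log(h0)    # logsumexp of [log h0, log b_s - log a*_s for s <= t]
    out = []
    for k_t, u_t in zip(k, u):
        log_a_star += -softplus(k_t)               # log(1 - z_t)
        log_b = -softplus(-k_t) + log_g(u_t)       # log(z_t * g(u_t))
        v = log_b - log_a_star
        m = max(acc, v)                            # one step of a running logsumexp
        acc = m + math.log(math.exp(acc - m) + math.exp(v - m))
        out.append(math.exp(log_a_star + acc))     # h_t = a*_t * (h0 + sum_s b_s / a*_s)
    return out
```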

mLSTMCell - Matrix LSTM cell with gated linear attention.
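The mLSTM keeps a d×d matrix memory and a normalizer vector. The memory is written with the outer product of a value and a key and read with a query; with scalar gates, this amounts to gated linear attention. The sketch below shows a single step following the xLSTM formulation. The gates are supplied as positive scalars, the output gate and log-space stabilizer are left out, and the names are hypothetical.

```python
import numpy as np

# Illustrative sketch; hypothetical names, not the memorax API.

def mlstm_step(C, n, q, k, v, i_gate, f_gate):
    """C: (d, d) memory; n: (d,) normalizer; q, k, v: (d,); gates: positive scalars."""
    d = k.shape[0]
    k = k / np.sqrt(d)
    C = f_gate * C + i_gate * np.outer(v, k)   # write value v under key k
    n = f_gate * n + i_gate * k
    denom = max(abs(n @ q), 1.0)
    h = (C @ q) / denom                        # read with query q
    return C, n, h

# Usage: write two values under orthogonal keys, then query with the first key.
d = 4
C, n = np.zeros((d, d)), np.zeros(d)
v1, v2 = np.array([1.0, 2.0, 3.0, 4.0]), np.array([-1.0, 0.0, 1.0, 0.0])
k1, k2 = np.array([2.0, 0.0, 0.0, 0.0]), np.array([0.0, 2.0, 0.0, 0.0])
C, n, _ = mlstm_step(C, n, k1, k1, v1, 1.0, 1.0)
C, n, h = mlstm_step(C, n, k1, k2, v2, 1.0, 1.0)  # h recovers v1
```

Because the second key is orthogonal to the query, it contributes nothing to C q, and the read returns the first value.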

FFMCell - Fast and Forgetful Memory cell.

LinearAttentionCell - Linear attention cell with kernelized features.
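With a positive feature map φ, linear attention replaces causal softmax attention with φ(q_t)ᵀ Σ_{s≤t} φ(k_s) v_sᵀ, normalized by φ(q_t)ᵀ Σ_{s≤t} φ(k_s). Both sums can be carried forward as a fixed-size recurrent state. The sketch below assumes the elu(x) + 1 feature map from Katharopoulos et al. (2020). The feature map used by `LinearAttentionCell` may differ, and the function names are hypothetical.

```python
import numpy as np

# Illustrative sketch; hypothetical names, not the memorax API.

def elu_plus_one(x):
    """Positive feature map phi(x) = elu(x) + 1."""
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention_recurrent(Q, K, V):
    """Causal linear attention run as an RNN. Q, K: (T, d); V: (T, m)."""
    phi_q, phi_k = elu_plus_one(Q), elu_plus_one(K)
    S = np.zeros((K.shape[1], V.shape[1]))   # running sum of phi(k) v^T
    z = np.zeros(K.shape[1])                 # running sum of phi(k), for normalization
    out = []
    for q, k, v in zip(phi_q, phi_k, V):
        S += np.outer(k, v)
        z += k
        out.append(q @ S / (q @ z))
    return np.array(out)
```

The loop produces the same outputs as the masked quadratic form tril(φ(Q) φ(K)ᵀ) V divided by its row sums, but its memory stays constant in the sequence length.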

RTUCell - Rotational Transformation Unit cell.
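A complex diagonal recurrence h ← λ h + x with λ = r e^{iθ} can run entirely in real arithmetic. Each complex entry is stored as a pair of reals and updated with a 2×2 rotation scaled by r. The sketch below demonstrates that equivalence for one state entry. It assumes this is the kind of rotation an RTU applies and is not the `RTUCell` implementation.

```python
import math

# Illustrative sketch; hypothetical name, not the memorax API.

def rotate_scale_step(h1, h2, r, theta, x1, x2):
    """Real-valued form of h <- (r e^{i theta}) h + x for one complex entry.

    (h1, h2) and (x1, x2) hold real and imaginary parts.
    """
    c, s = r * math.cos(theta), r * math.sin(theta)
    new_h1 = c * h1 - s * h2 + x1   # real part of the complex product
    new_h2 = s * h1 + c * h2 + x2   # imaginary part of the complex product
    return new_h1, new_h2
```

Each pair interacts only with itself, so the recurrence stays block-diagonal. This structure keeps forward-mode (RTRL-style) sensitivity tracking far cheaper than it would be for a dense recurrent matrix.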

Attention

SelfAttention - Multi-head self-attention with optional cross-segment memory.
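Transformer-XL gives self-attention memory across segments by caching the previous segment's activations. The current queries then attend to that cache alongside the causally masked current segment. The NumPy sketch below assumes that scheme. The function and argument names are hypothetical, positional encodings and the output projection are omitted, and the sketch does not describe how `SelfAttention` itself manages its memory.

```python
import numpy as np

# Illustrative sketch of Transformer-XL-style memory; hypothetical names, not the memorax API.

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def mha_with_memory(x, mem, Wq, Wk, Wv, num_heads):
    """x: (T, D) current segment; mem: (M, D) cached earlier activations; W*: (D, D)."""
    T, D = x.shape
    M = mem.shape[0]
    dh = D // num_heads
    ctx = np.concatenate([mem, x], axis=0)                           # (M + T, D)
    q = (x @ Wq).reshape(T, num_heads, dh).transpose(1, 0, 2)        # (H, T, dh)
    k = (ctx @ Wk).reshape(M + T, num_heads, dh).transpose(1, 0, 2)  # (H, M+T, dh)
    v = (ctx @ Wv).reshape(M + T, num_heads, dh).transpose(1, 0, 2)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(dh)                  # (H, T, M+T)
    # query t sees every memory slot and current positions <= t
    mask = np.arange(M + T)[None, :] <= (M + np.arange(T))[:, None]
    scores = np.where(mask, scores, -np.inf)
    out = softmax(scores) @ v                                        # (H, T, dh)
    new_mem = x                                                      # cache for the next segment
    return out.transpose(1, 0, 2).reshape(T, D), new_mem
```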

Wrappers

SequenceModelWrapper - Wraps non-recurrent models as sequence models.
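A stateless model can share a recurrent model's interface if an empty carry is threaded through unchanged. The sketch below assumes a hypothetical (carry, inputs) -> (carry, outputs) convention plus an `initialize_carry` hook. The wrapper's real signature is not shown on this page.

```python
from typing import Any, Callable, List, Tuple

# Illustrative sketch; the interface shape is an assumption, not the memorax API.

def as_sequence_model(fn: Callable[[Any], Any]):
    """Lift a stateless per-step function into a (carry, xs) -> (carry, ys) model."""
    def apply(carry: Any, xs: List[Any]) -> Tuple[Any, List[Any]]:
        return carry, [fn(x) for x in xs]   # carry passes through untouched

    def initialize_carry() -> Any:
        return None                          # nothing to remember

    return apply, initialize_carry
```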

RL2Wrapper - RL² wrapper that preserves hidden state across episode boundaries.
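In RL², the recurrent state is reset at the start of each trial rather than at each episode. Whatever the agent inferred in earlier episodes of the same task therefore stays available. The policy also receives the previous action, reward, and termination flag as input. The toy sketch below isolates only the reset rule, using a hypothetical helper and a running-sum "cell".

```python
# Toy sketch of the reset rule only; hypothetical helper, not the memorax API.

def run_rnn(cell, h0, inputs, reset_flags):
    """Scan `cell` over inputs, restoring h0 wherever reset_flags[t] is True."""
    h, hs = h0, []
    for x, reset in zip(inputs, reset_flags):
        if reset:
            h = h0              # forget everything before step t
        h = cell(h, x)
        hs.append(h)
    return hs

cell = lambda h, x: h + x                                     # toy memory: running sum
xs = [1, 1, 1, 1, 1, 1]
episode_start = [True, False, False, True, False, False]      # two episodes
trial_start = [True, False, False, False, False, False]       # one trial spanning both
standard = run_rnn(cell, 0, xs, episode_start)  # [1, 2, 3, 1, 2, 3]
rl2 = run_rnn(cell, 0, xs, trial_start)         # [1, 2, 3, 4, 5, 6]
```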

RTRL - Real-Time Recurrent Learning wrapper for online gradient computation.
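RTRL computes gradients online by carrying the sensitivity of the state with respect to the parameters forward alongside the state. Backpropagation through time, by contrast, must store the history. The sketch below applies this to a one-unit tanh RNN and its single recurrent weight, assuming a squared-error loss. It illustrates the technique and is not the `RTRL` wrapper's interface.

```python
import math

# Illustrative sketch; hypothetical name, not the memorax API.

def rtrl_grad_w(w, u, xs, targets, h0=0.0):
    """Loss sum_t 0.5*(h_t - y_t)^2 and its gradient w.r.t. w, computed online.

    Model: h_t = tanh(w * h_{t-1} + u * x_t). The sensitivity s_t = dh_t/dw is
    carried forward with the state, so no history is stored.
    """
    h, s = h0, 0.0
    loss, grad = 0.0, 0.0
    for x, y in zip(xs, targets):
        h_prev = h
        h = math.tanh(w * h_prev + u * x)
        s = (1.0 - h * h) * (h_prev + w * s)   # chain rule through h_{t-1}, using s_{t-1}
        loss += 0.5 * (h - y) ** 2
        grad += (h - y) * s                    # gradient contribution available at step t
    return loss, grad
```

In a dense network with n units, the sensitivity of every state entry to every recurrent weight is an n × n² array. This cost is why exact RTRL is usually paired with diagonal or otherwise structured recurrences.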