Sequence Models

Memorax supports three families of sequence models for memory-augmented RL.

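Overview

| Family             | Models                              | Strengths                                      |
|--------------------|-------------------------------------|------------------------------------------------|
| RNNs               | LSTM, GRU, sLSTM, SHM, RTU          | Simple, well-understood                        |
| State Space Models | S5, LRU, Mamba, MinGRU, mLSTM, FFM  | Efficient on long sequences via parallel scan  |
| Attention          | SelfAttention, LinearAttention      | Flexible, parallel training                    |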

RNNs

LSTM / GRU

Standard recurrent networks built from Flax's built-in cells and wrapped with RNN:

import flax.linen as nn
from memorax.networks import RNN

lstm_torso = RNN(cell=nn.LSTMCell(features=64))
gru_torso = RNN(cell=nn.GRUCell(features=64))
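To picture roughly what a wrapper like RNN does with a cell, consider a loop that threads a carry through the sequence one step at a time. The sketch below is a minimal NumPy illustration, not Memorax's code. It assumes the GRU equations in the form documented for Flax's GRUCell, with biases omitted. The helpers gru_cell and run_rnn are hypothetical.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gru_cell(params, h, x):
    """One GRU step (biases omitted). Returns (new_carry, output)."""
    r = sigmoid(params["W_ir"] @ x + params["W_hr"] @ h)        # reset gate
    z = sigmoid(params["W_iz"] @ x + params["W_hz"] @ h)        # update gate
    n = np.tanh(params["W_in"] @ x + r * (params["W_hn"] @ h))  # candidate state
    h_new = (1.0 - z) * n + z * h                               # interpolate old and new
    return h_new, h_new


def run_rnn(cell, params, h0, xs):
    """Thread the carry through the sequence one step at a time."""
    h, ys = h0, []
    for x in xs:
        h, y = cell(params, h, x)
        ys.append(y)
    return h, np.stack(ys)


rng = np.random.default_rng(0)
in_dim, hidden, T = 3, 4, 5
params = {
    name: rng.normal(scale=0.5, size=(hidden, in_dim if name.startswith("W_i") else hidden))
    for name in ["W_ir", "W_hr", "W_iz", "W_hz", "W_in", "W_hn"]
}
h_final, ys = run_rnn(gru_cell, params, np.zeros(hidden), rng.normal(size=(T, in_dim)))
print(ys.shape)  # (5, 4)
```

Each step needs the previous carry, and the GRU update is nonlinear in h. As a result, this loop cannot be split across time steps.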

sLSTM

Scalar LSTM with enhanced gating and feature normalization:

from memorax.networks import RNN, sLSTMCell, sLSTMConfig

slstm = RNN(cell=sLSTMCell(config=sLSTMConfig(features=64, hidden_dim=64)))

SHM (Stable Hadamard Memory)

from memorax.networks import RNN, SHMCell, SHMConfig

shm = RNN(cell=SHMCell(config=SHMConfig(features=64, output_features=32)))

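State Space Models

Except for RTU, which is a recurrent cell used with the RNN wrapper, the models in this section are wrapped with Memoroid, which evaluates their recurrence as a parallel scan over the sequence.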
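A parallel scan applies when each step is a linear update h_t = a_t * h_{t-1} + b_t. Composing two such steps gives another step of the same form, so the composition is associative, and prefixes can be combined in O(log T) rounds. The NumPy sketch below illustrates the idea with a Hillis-Steele scan on scalars. It is not how Memoroid is implemented; a JAX version would typically use jax.lax.associative_scan. The names combine, sequential_scan, and parallel_scan are hypothetical.

```python
import numpy as np


def combine(left, right):
    """Associative operator for h_t = a_t * h_{t-1} + b_t.
    Applying step `left` and then step `right` is again a step (a, b)."""
    a_l, b_l = left
    a_r, b_r = right
    return a_l * a_r, a_r * b_l + b_r


def sequential_scan(a, b, h0=0.0):
    """Reference: evaluate the recurrence one step at a time."""
    h, out = h0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return np.array(out)


def parallel_scan(a, b):
    """Hillis-Steele inclusive scan: O(log T) rounds, each fully vectorized."""
    a, b = a.copy(), b.copy()
    T, offset = len(a), 1
    while offset < T:
        # Position t absorbs the prefix ending at t - offset (both read from the previous round).
        new_a, new_b = combine((a[:-offset], b[:-offset]), (a[offset:], b[offset:]))
        a[offset:], b[offset:] = new_a, new_b
        offset *= 2
    return b  # with h0 = 0, the accumulated offset b equals h_t


rng = np.random.default_rng(0)
a = rng.uniform(0.5, 1.0, size=13)   # decay factors
b = rng.normal(size=13)              # inputs
print(np.allclose(parallel_scan(a, b), sequential_scan(a, b)))  # True
```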

LRU (Linear Recurrent Unit)

from memorax.networks import Memoroid, LRUCell, LRUConfig

lru = Memoroid(cell=LRUCell(config=LRUConfig(features=64, hidden_dim=64)))
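The LRU paper (Orvieto et al., 2023) describes a diagonal, complex-valued linear recurrence whose eigenvalues are parametrized to stay inside the unit circle. The sketch below follows that description in NumPy as an illustration only, not LRUCell's implementation. It runs sequentially and uses made-up parameter names (nu_log, theta_log, B, C, D). It also assumes equal input and output widths, so that the skip term D * x is well defined.

```python
import numpy as np


def lru_sketch(xs, nu_log, theta_log, B, C, D):
    """Sequential LRU-style recurrence:
    h_t = lam * h_{t-1} + gamma * (B @ x_t),  y_t = Re(C @ h_t) + D * x_t."""
    # |lam| = exp(-exp(nu_log)) < 1 for any real nu_log, so the state cannot blow up.
    lam = np.exp(-np.exp(nu_log) + 1j * np.exp(theta_log))
    gamma = np.sqrt(1.0 - np.abs(lam) ** 2)  # normalizes the input contribution
    h = np.zeros_like(lam)
    ys = []
    for x in xs:
        h = lam * h + gamma * (B @ x)        # element-wise decay and rotation per state
        ys.append((C @ h).real + D * x)      # project back to real-valued features
    return np.stack(ys), lam


rng = np.random.default_rng(0)
features, hidden, T = 3, 8, 6
B = (rng.normal(size=(hidden, features)) + 1j * rng.normal(size=(hidden, features))) / np.sqrt(2 * features)
C = (rng.normal(size=(features, hidden)) + 1j * rng.normal(size=(features, hidden))) / np.sqrt(hidden)
ys, lam = lru_sketch(
    rng.normal(size=(T, features)),
    nu_log=rng.normal(size=hidden),
    theta_log=rng.normal(size=hidden),
    B=B, C=C, D=rng.normal(size=features),
)
print(ys.shape, np.abs(lam).max() < 1.0)  # (6, 3) True
```

The recurrence is linear in h, so it has the form that a parallel scan can evaluate; the loop here is only for readability.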

S5 (Simplified Structured State Space)

from memorax.networks import Memoroid, S5Cell, S5Config

s5 = Memoroid(cell=S5Cell(config=S5Config(features=64, hidden_dim=64)))
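S5 (Smith et al., 2023) works with a diagonalized continuous-time state space model and discretizes it before running the recurrence. Zero-order hold (ZOH) is the standard choice for this step. The sketch below shows that discretization for a diagonal system in NumPy. It is illustrative only: the function names are hypothetical, and S5Cell's parametrization (for example, how the step size is learned) is not reproduced.

```python
import numpy as np


def zoh_discretize(Lambda, B, delta):
    """Zero-order-hold discretization of the diagonal system dx/dt = Lambda * x + B @ u.
    Lambda: (P,) complex with negative real part; B: (P, H); delta: step size."""
    Lambda_bar = np.exp(Lambda * delta)                   # exact decay over one step
    B_bar = ((Lambda_bar - 1.0) / Lambda)[:, None] * B    # input held constant over the step
    return Lambda_bar, B_bar


def discrete_step(x, u, Lambda_bar, B_bar):
    """One step of the discretized linear recurrence."""
    return Lambda_bar * x + B_bar @ u


Lambda = np.array([-0.5 + 1.0j, -1.0 + 0.3j])
B = np.array([[1.0, 0.5], [0.2, -0.3]], dtype=complex)
Lambda_bar, B_bar = zoh_discretize(Lambda, B, delta=0.1)
x0, u = np.array([1.0 + 0.0j, -0.5 + 0.2j]), np.array([0.3, -0.7])
print(discrete_step(x0, u, Lambda_bar, B_bar))
```

After discretization the model is again a linear recurrence x_t = Lambda_bar * x_{t-1} + B_bar @ u_t, which is the form a parallel scan can evaluate.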

Mamba (Selective State Space Model)

from memorax.networks import Memoroid, Mamba2Cell, Mamba2Config

mamba = Memoroid(cell=Mamba2Cell(config=Mamba2Config(features=64, num_heads=4, head_dim=16)))
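Mamba-style selective SSMs make the step size and the B and C projections functions of the current input, so the model can choose what to write to and read from its state. The sketch below shows one head with a scalar decay per head, as in Mamba-2, in plain NumPy. The parameter names (W_B, W_C, w_dt, A_log) are hypothetical. Mamba2Cell's actual parametrization, gating, and convolution are not reproduced.

```python
import numpy as np


def softplus(x):
    return np.logaddexp(0.0, x)  # log(1 + e^x), numerically stable


def selective_scan_head(xs, W_B, W_C, w_dt, A_log):
    """One head of a Mamba-2-style selective SSM, run sequentially.
    xs: (T, head_dim); W_B, W_C: (state_dim, head_dim); w_dt: (head_dim,); A_log: scalar."""
    A = -np.exp(A_log)                           # negative scalar, so exp(dt * A) lies in (0, 1)
    H = np.zeros((xs.shape[1], W_B.shape[0]))    # state: (head_dim, state_dim)
    ys = []
    for x in xs:
        dt = softplus(w_dt @ x)                  # input-dependent step size (> 0)
        B, C = W_B @ x, W_C @ x                  # input-dependent write/read vectors
        H = np.exp(dt * A) * H + dt * np.outer(x, B)   # decay, then write x along B
        ys.append(H @ C)                         # read out along C: (head_dim,)
    return np.stack(ys)


rng = np.random.default_rng(0)
T, head_dim, state_dim = 8, 4, 6
W_B, W_C = rng.normal(size=(2, state_dim, head_dim)) / np.sqrt(head_dim)
w_dt = rng.normal(size=head_dim)
xs = rng.normal(size=(T, head_dim))
ys = selective_scan_head(xs, W_B, W_C, w_dt, A_log=0.0)
print(ys.shape)  # (8, 4)
```

For a fixed input sequence, the decay and write terms are known in advance, so the state update is still linear in H and can be scanned in parallel.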

MinGRU

Minimal GRU variant computed in log-space for numerical stability:

from memorax.networks import Memoroid, MinGRUCell, MinGRUConfig

mingru = Memoroid(cell=MinGRUCell(config=MinGRUConfig(features=64)))
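In the minGRU formulation (Feng et al., 2024, "Were RNNs All We Needed?"), the gate z_t and the candidate h~_t depend only on x_t. As a result, h_t = (1 - z_t) * h_{t-1} + z_t * h~_t is a linear recurrence in h. If the candidate is kept positive through a function g, the whole sequence can be computed in log space with a cumulative sum and a cumulative log-sum-exp. The NumPy sketch below illustrates this under those assumptions and is not MinGRUCell's code. Gate and candidate pre-activations are passed in directly rather than produced by linear layers, and the result is checked against a step-by-step loop.

```python
import numpy as np


def log_sigmoid(x):
    return -np.logaddexp(0.0, -x)  # log(sigmoid(x)), numerically stable


def log_g(x):
    """log of g(x) = x + 0.5 (x >= 0) or sigmoid(x) (x < 0); g keeps candidates positive."""
    return np.where(x >= 0, np.log(np.maximum(x, 0.0) + 0.5), log_sigmoid(x))


def mingru_log_space(k, pre_h, h0):
    """Parallel minGRU over a whole sequence. k, pre_h: (T, D) pre-activations; h0: (D,) > 0."""
    log_z = log_sigmoid(k)                 # log z_t
    log_a = log_sigmoid(-k)                # log(1 - z_t)
    log_b = log_z + log_g(pre_h)           # log(z_t * h~_t)
    a_star = np.cumsum(log_a, axis=0)      # log prod_{j<=t} (1 - z_j)
    log_acc = np.logaddexp.accumulate(log_b - a_star, axis=0)
    return np.exp(a_star + np.logaddexp(np.log(h0), log_acc))


def mingru_sequential(k, pre_h, h0):
    """Reference: the same recurrence, one step at a time."""
    z = 1.0 / (1.0 + np.exp(-k))
    h_tilde = np.exp(log_g(pre_h))
    h, out = h0, []
    for t in range(len(k)):
        h = (1.0 - z[t]) * h + z[t] * h_tilde[t]
        out.append(h)
    return np.stack(out)


rng = np.random.default_rng(0)
k, pre_h = rng.normal(size=(2, 10, 3))
h0 = np.full(3, 0.5)
print(np.allclose(mingru_log_space(k, pre_h, h0), mingru_sequential(k, pre_h, h0)))  # True
```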

mLSTM (Matrix LSTM)

Matrix LSTM using gated linear attention:

from memorax.networks import Memoroid, mLSTMCell, mLSTMConfig

mlstm = Memoroid(cell=mLSTMCell(config=mLSTMConfig(features=64, hidden_dim=64, num_heads=4)))
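The xLSTM paper (Beck et al., 2024) describes the mLSTM memory as a matrix C_t that is updated with an outer product v_t k_t^T, read with a query q_t, and paired with a normalizer vector n_t. Unrolled over time, this is causally masked linear attention with decay weights, which is the "gated linear attention" view mentioned above. The NumPy sketch below checks that equivalence under simplifying assumptions: gates are supplied directly as values in (0, 1), and the exponential input gate and its stabilizer are omitted. It is not mLSTMCell's implementation.

```python
import numpy as np


def mlstm_recurrent(q, k, v, i_gate, f_gate):
    """Simplified mLSTM recurrence. q, k: (T, d_k); v: (T, d_v); i_gate, f_gate: (T,) in (0, 1)."""
    d_k = k.shape[1]
    C = np.zeros((v.shape[1], d_k))     # matrix memory
    n = np.zeros(d_k)                   # normalizer state
    hs = []
    for t in range(len(q)):
        k_t = k[t] / np.sqrt(d_k)
        C = f_gate[t] * C + i_gate[t] * np.outer(v[t], k_t)   # decay, then write v along k
        n = f_gate[t] * n + i_gate[t] * k_t
        hs.append(C @ q[t] / max(abs(n @ q[t]), 1.0))         # read with the query
    return np.stack(hs)


def mlstm_parallel(q, k, v, i_gate, f_gate):
    """Same outputs computed as causally masked, gated linear attention."""
    d_k = k.shape[1]
    log_f = np.concatenate([[0.0], np.cumsum(np.log(f_gate))])  # log_f[t + 1] = sum_{j<=t} log f_j
    # D[t, s] = i_s * prod_{j=s+1..t} f_j for s <= t, and 0 otherwise.
    D = np.tril(np.exp(log_f[1:, None] - log_f[None, 1:]) * i_gate[None, :])
    scores = D * (q @ k.T) / np.sqrt(d_k)
    norm = np.maximum(np.abs(scores.sum(axis=1)), 1.0)
    return (scores @ v) / norm[:, None]


rng = np.random.default_rng(0)
T, d_k, d_v = 7, 4, 3
q, k = rng.normal(size=(2, T, d_k))
v = rng.normal(size=(T, d_v))
i_gate, f_gate = rng.uniform(0.1, 0.9, size=T), rng.uniform(0.5, 0.99, size=T)
print(np.allclose(mlstm_recurrent(q, k, v, i_gate, f_gate),
                  mlstm_parallel(q, k, v, i_gate, f_gate)))  # True
```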

FFM (Fast and Forgetful Memory)

from memorax.networks import Memoroid, FFMCell, FFMConfig

ffm = Memoroid(cell=FFMCell(config=FFMConfig(features=64, memory_size=32, context_size=64)))

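RTU (Rotational Transformation Unit)

RTU is a recurrent cell (listed with the RNNs in the overview) and, unlike the models above, is wrapped with RNN rather than Memoroid:
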
from memorax.networks import RNN, RTUCell, RTUConfig

rtu = RNN(cell=RTUCell(config=RTUConfig(features=64, hidden_dim=64)))

Attention

Self-Attention

Multi-head self-attention (used directly, no wrapper needed):

from memorax.networks import SelfAttention, SelfAttentionConfig

attention = SelfAttention(config=SelfAttentionConfig(features=64, num_heads=4, context_length=128))
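The sketch below shows single-head, causal, scaled dot-product attention over a sliding window. It assumes that context_length limits how many recent steps each position can attend to; that reading of the parameter is an assumption, not something stated here. It is a NumPy illustration, not SelfAttention's implementation, and it omits multiple heads, output projections, and positional embeddings.

```python
import numpy as np


def causal_window_attention(x, Wq, Wk, Wv, context_length):
    """Step t attends to steps max(0, t - context_length + 1) .. t. Returns (outputs, weights)."""
    T = x.shape[0]
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    scores = q @ k.T / np.sqrt(q.shape[1])
    t_idx, s_idx = np.arange(T)[:, None], np.arange(T)[None, :]
    allowed = (s_idx <= t_idx) & (s_idx > t_idx - context_length)  # causal sliding window
    scores = np.where(allowed, scores, -np.inf)
    scores -= scores.max(axis=1, keepdims=True)   # stable softmax (diagonal is always allowed)
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ v, weights


rng = np.random.default_rng(0)
T, d = 6, 4
x = rng.normal(size=(T, d))
Wq, Wk, Wv = rng.normal(size=(3, d, d)) / np.sqrt(d)
out, weights = causal_window_attention(x, Wq, Wk, Wv, context_length=3)
print(out.shape)  # (6, 4)
```

Every step in the window is processed at once, which is what makes training parallel. The cost is a T x T score matrix per head.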

Linear Attention

Attention whose cost grows linearly with sequence length, obtained by replacing the softmax with a kernel feature map; it is wrapped with Memoroid:

from memorax.networks import Memoroid, LinearAttentionCell, LinearAttentionConfig

linear_attention = Memoroid(cell=LinearAttentionCell(config=LinearAttentionConfig(features=64, num_heads=4, head_dim=16)))
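With a positive feature map phi, causal attention can be rewritten as running sums. These are S_t, the sum of phi(k_s) v_s^T over past steps, and z_t, the sum of phi(k_s). The sums are updated once per step, so no T x T matrix is needed. The single-head NumPy sketch below uses the elu(x) + 1 feature map from Katharopoulos et al. (2020) and checks the recurrent form against an explicit masked computation. The choice of feature map is an assumption; LinearAttentionCell may use a different one.

```python
import numpy as np


def phi(x):
    """Feature map elu(x) + 1; keeps features strictly positive."""
    return np.where(x > 0, x + 1.0, np.exp(x))


def linear_attention_recurrent(q, k, v):
    """O(T) causal linear attention: running sums replace the T x T attention matrix."""
    S = np.zeros((k.shape[1], v.shape[1]))   # sum_s phi(k_s) v_s^T
    z = np.zeros(k.shape[1])                 # sum_s phi(k_s)
    out = []
    for q_t, k_t, v_t in zip(phi(q), phi(k), v):
        S += np.outer(k_t, v_t)
        z += k_t
        out.append(q_t @ S / (q_t @ z))
    return np.stack(out)


def linear_attention_quadratic(q, k, v):
    """Reference: the same causal attention with an explicit T x T weight matrix."""
    A = np.tril(phi(q) @ phi(k).T)           # unnormalized causal weights
    return (A @ v) / A.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
q, k, v = rng.normal(size=(3, 9, 4))
print(np.allclose(linear_attention_recurrent(q, k, v), linear_attention_quadratic(q, k, v)))  # True
```

The state update S_t = S_{t-1} + phi(k_t) v_t^T is a linear recurrence with a decay of 1. This is consistent with the model being run through Memoroid like the state space models.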

RTRL (Real-Time Recurrent Learning)

Wraps any sequence model so that gradients through the recurrence are computed in real time, propagated forward step by step instead of by backpropagation through time:

from memorax.networks import RTRL, RNN
import flax.linen as nn

rtrl_gru = RTRL(model=RNN(cell=nn.GRUCell(features=64)))
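RTRL carries the derivative of the state with respect to the parameters forward in time, so a gradient is available at every step without storing the history. The scalar sketch below uses a hypothetical rtrl_grad helper, not the RTRL wrapper's code. It shows the sensitivity update for h_t = tanh(w * h_{t-1} + u * x_t) with a squared-error loss. For a vector state, the sensitivity becomes a Jacobian with one entry per state-parameter pair, which is what makes exact RTRL expensive for general recurrent networks.

```python
import numpy as np


def rtrl_grad(w, u, xs, ys):
    """Forward-mode (RTRL) gradient of L = 0.5 * sum_t (h_t - y_t)^2 w.r.t. the recurrent weight w,
    for the scalar RNN h_t = tanh(w * h_{t-1} + u * x_t). No history is stored."""
    h, s, grad, loss = 0.0, 0.0, 0.0, 0.0   # s = dh_t/dw, carried forward in time
    for x, y in zip(xs, ys):
        h_prev = h
        h = np.tanh(w * h_prev + u * x)
        s = (1.0 - h ** 2) * (h_prev + w * s)   # chain rule through the recurrence
        grad += (h - y) * s                     # gradient contribution available at this step
        loss += 0.5 * (h - y) ** 2
    return loss, grad


rng = np.random.default_rng(0)
xs, ys = rng.normal(size=20), rng.normal(size=20)
w, u, eps = 0.7, 0.5, 1e-6
loss, grad = rtrl_grad(w, u, xs, ys)
# Check against a central finite difference of the loss.
fd = (rtrl_grad(w + eps, u, xs, ys)[0] - rtrl_grad(w - eps, u, xs, ys)[0]) / (2 * eps)
print(np.isclose(grad, fd, rtol=1e-5, atol=1e-6))  # True
```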

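RL2 Wrapper

For meta-RL (RL²), keeps the hidden state across episode boundaries within a trial instead of resetting it at every episode, so memory carries over between episodes of the same task: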

from memorax.networks import RL2Wrapper, RNN
import flax.linen as nn

rl2 = RL2Wrapper(model=RNN(cell=nn.GRUCell(features=64)))
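The difference between a standard recurrent agent and the RL² setting comes down to when the carry is reset. The pure-Python sketch below is not RL2Wrapper's code. It uses a hypothetical run_with_resets helper, with a toy step counter standing in for a real hidden state.

```python
def run_with_resets(dones, step_fn, init):
    """Thread a carry through time, resetting it to `init` after any step whose flag is set."""
    carry, carries = init, []
    for done in dones:
        carry = step_fn(carry)
        carries.append(carry)
        if done:
            carry = init
    return carries


def count_steps(carry):
    """Toy memory: number of steps seen since the last reset."""
    return carry + 1


# One trial (one task) made of two 3-step episodes.
episode_done = [False, False, True, False, False, True]
trial_done = [False, False, False, False, False, True]

print(run_with_resets(episode_done, count_steps, 0))  # [1, 2, 3, 1, 2, 3]: reset every episode
print(run_with_resets(trial_done, count_steps, 0))    # [1, 2, 3, 4, 5, 6]: RL2-style, memory spans episodes
```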

Choosing a Model

For Short Episodes (< 100 steps)

  • LSTM/GRU: Simple and effective

  • sLSTM: Enhanced gating

For Long Episodes (100-1000 steps)

  • S5/LRU: Efficient state space models

  • Mamba: Input-dependent (selective) state updates

  • RTU: Efficient rotational dynamics

For Very Long Episodes (> 1000 steps)

  • SelfAttention: With positional embeddings

  • LinearAttention: Linear complexity

For Memory-Intensive Tasks

  • FFM/SHM: Explicit memory mechanisms

  • mLSTM: Matrix memory

Example: Mamba Agent

from memorax.algorithms import PPO, PPOConfig
from memorax.networks import Network, FeatureExtractor, Memoroid, Mamba2Cell, Mamba2Config, heads
from memorax.networks.layers import Flatten

actor = Network(
    feature_extractor=FeatureExtractor(observation_extractor=Flatten()),
    torso=Memoroid(cell=Mamba2Cell(config=Mamba2Config(features=64, num_heads=4, head_dim=16))),
    head=heads.Categorical(action_dim=4),
)

critic = Network(
    feature_extractor=FeatureExtractor(observation_extractor=Flatten()),
    torso=Memoroid(cell=Mamba2Cell(config=Mamba2Config(features=64, num_heads=4, head_dim=16))),
    head=heads.VNetwork(),
)

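# config (a PPOConfig), env, env_params, and optimizer are assumed to be defined beforehand;
# optimizer is passed twice, presumably once for the actor and once for the critic.
agent = PPO(config, env, env_params, actor, critic, optimizer, optimizer)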