# Sequence Models
This guide covers the available sequence models for memory-augmented RL.
## Overview
Memorax supports three families of sequence models:
| Family | Models | Strengths |
|---|---|---|
| RNNs | LSTM, GRU, sLSTM, SHM | Simple, well-understood |
| State Space Models | S5, LRU, Mamba, MinGRU, mLSTM, FFM | Efficient on long sequences |
| Attention | SelfAttention, LinearAttention | Flexible, parallel training |
## RNNs
### LSTM / GRU
Standard recurrent networks using Flax cells:
```python
import flax.linen as nn
from memorax.networks import RNN

# LSTM
lstm_torso = RNN(cell=nn.LSTMCell(features=64))

# GRU
gru_torso = RNN(cell=nn.GRUCell(features=64))
```
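Conceptually, `RNN` unrolls a cell along the time axis and carries the hidden state from one step to the next. The NumPy sketch below illustrates that loop for a GRU using the standard gated update. It is not Memorax's or Flax's code, and Flax's `GRUCell` may differ in its biases and initialization. The helpers `gru_step` and `unroll` are hypothetical names.

```python
import numpy as np

# Illustrative sketch only: gru_step / unroll are hypothetical helpers, not Memorax or Flax API.

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(params, h, x):
    """One GRU step: returns the new hidden state (also used as the output)."""
    W_r, U_r, W_z, U_z, W_n, U_n = params  # input (W_*) and recurrent (U_*) weights; biases omitted
    r = sigmoid(W_r @ x + U_r @ h)           # reset gate
    z = sigmoid(W_z @ x + U_z @ h)           # update gate
    n = np.tanh(W_n @ x + r * (U_n @ h))     # candidate state
    return (1.0 - z) * n + z * h             # interpolate between candidate and previous state

def unroll(params, xs, features):
    """Apply gru_step along the leading (time) axis of xs, starting from a zero state."""
    h = np.zeros(features)
    outputs = []
    for x in xs:
        h = gru_step(params, h, x)
        outputs.append(h)
    return np.stack(outputs)

rng = np.random.default_rng(0)
in_dim, features, T = 3, 8, 5
# Even indices are input weights (features x in_dim); odd indices are recurrent weights (features x features).
params = tuple(
    rng.normal(scale=0.3, size=(features, in_dim if i % 2 == 0 else features))
    for i in range(6)
)
ys = unroll(params, rng.normal(size=(T, in_dim)), features)
print(ys.shape)  # (5, 8)
```

The hidden state starts at zero. Each update is a convex combination of the previous state and a `tanh` output, so every entry stays strictly inside (-1, 1).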
### sLSTM
Scalar LSTM (from the xLSTM family) with exponential gating and a normalizer state:
```python
from memorax.networks import RNN, sLSTMCell

slstm = RNN(cell=sLSTMCell(features=64))
```
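An exponential input gate can overflow. To handle this, the sLSTM described in the xLSTM paper (Beck et al., 2024) tracks a normalizer state `n` and a log-space stabilizer `m`. The NumPy sketch below follows those equations for a single head with dense recurrent weights. It is not Memorax's `sLSTMCell`, and the parameter layout `(W, R, b)` and the helper names are hypothetical. It checks that the stabilized and naive recurrences give the same hidden states.

```python
import numpy as np

# Illustrative sketch only: parameter layout and helper names are hypothetical.

def sigmoid(x):
    return 0.5 * (1.0 + np.tanh(0.5 * x))  # overflow-free logistic function

def split_preacts(params, h, x):
    W, R, b = params  # W: (4d, in), R: (4d, d) recurrent weights, b: (4d,)
    return np.split(W @ x + R @ h + b, 4)  # z, i, f, o pre-activations

def slstm_step(params, state, x):
    """Stabilized sLSTM step; state = (c, n, m, h)."""
    c, n, m, h = state
    z_pre, i_pre, f_pre, o_pre = split_preacts(params, h, x)
    log_i = i_pre                          # exponential input gate: i = exp(i_pre)
    log_f = -np.logaddexp(0.0, -f_pre)     # log(sigmoid(f_pre)), computed stably
    m_new = np.maximum(log_f + m, log_i)   # running log-scale stabilizer
    i = np.exp(log_i - m_new)              # rescaled gates are both <= 1
    f = np.exp(log_f + m - m_new)
    c = f * c + i * np.tanh(z_pre)         # cell state
    n = f * n + i                          # normalizer state
    h = sigmoid(o_pre) * (c / n)           # normalized, output-gated hidden state
    return c, n, m_new, h

def slstm_step_naive(params, state, x):
    """Same recurrence without stabilization (can overflow for large i_pre)."""
    c, n, m, h = state
    z_pre, i_pre, f_pre, o_pre = split_preacts(params, h, x)
    i, f = np.exp(i_pre), sigmoid(f_pre)
    c = f * c + i * np.tanh(z_pre)
    n = f * n + i
    return c, n, m, sigmoid(o_pre) * (c / n)

def run(step, params, xs, d):
    state = tuple(np.zeros(d) for _ in range(4))  # c, n, m, h all start at zero
    hs = []
    for x in xs:
        state = step(params, state, x)
        hs.append(state[3])
    return np.stack(hs)

rng = np.random.default_rng(0)
d, in_dim, T = 4, 3, 6
params = (rng.normal(size=(4 * d, in_dim)), rng.normal(scale=0.5, size=(4 * d, d)), np.zeros(4 * d))
xs = rng.normal(size=(T, in_dim))
print(np.allclose(run(slstm_step, params, xs, d), run(slstm_step_naive, params, xs, d)))  # True
```

The stabilizer rescales `c` and `n` by the same factor `exp(-m)`, so their ratio, and therefore `h`, is unchanged. Only the stabilized version stays finite when input-gate pre-activations are large enough for `exp` to overflow.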
### SHM (Stable Hadamard Memory)
```python
from memorax.networks import RNN, SHMCell

shm = RNN(cell=SHMCell(features=64, memory_size=32))
```
## State Space Models
All state space models use the Memoroid wrapper with their respective cells.
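Many of these cells reduce to a recurrence of the form `h_t = a_t * h_{t-1} + b_t`, possibly with matrix- or complex-valued state. Composing two such updates is associative, so a whole sequence can be evaluated with a parallel prefix scan in O(log T) rounds instead of T sequential steps. The name suggests `Memoroid` builds on this "memory monoid" idea, but that is an assumption here. The NumPy sketch below shows the technique with a Hillis-Steele scan and is not the wrapper's implementation. In JAX the same pattern is usually written with `jax.lax.associative_scan`.

```python
import numpy as np

# Illustrative sketch of a parallel (associative) scan; not Memoroid's actual code.

def combine(left, right):
    """Compose two affine updates h -> a*h + b, applying `left` first, then `right`."""
    a_l, b_l = left
    a_r, b_r = right
    return a_l * a_r, a_r * b_l + b_r

def associative_scan(a, b):
    """Inclusive Hillis-Steele prefix scan over the time axis.

    Returns h_t = a_t * h_{t-1} + b_t for every t, with h_0 = 0.
    Each while-iteration updates all positions at once; there are O(log T) iterations.
    """
    a, b = a.copy(), b.copy()
    offset = 1
    while offset < a.shape[0]:
        # Position t absorbs the prefix ending at t - offset (computed before assignment).
        new_a, new_b = combine((a[:-offset], b[:-offset]), (a[offset:], b[offset:]))
        a[offset:], b[offset:] = new_a, new_b
        offset *= 2
    return b

def sequential_scan(a, b):
    h, out = np.zeros_like(b[0]), []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return np.stack(out)

rng = np.random.default_rng(0)
a = rng.uniform(0.5, 1.0, size=(10, 3))  # per-step decay: 10 steps, 3 channels
b = rng.normal(size=(10, 3))             # per-step input
print(np.allclose(associative_scan(a, b), sequential_scan(a, b)))  # True
```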
### LRU (Linear Recurrent Unit)
Diagonal linear recurrence with complex-valued eigenvalues, parameterized to stay stable:
```python
from memorax.networks import Memoroid, LRUCell

lru = Memoroid(cell=LRUCell(features=64, hidden_dim=64))
```
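In the LRU (Orvieto et al., 2023) the state update is `h_t = λ ⊙ h_{t-1} + γ ⊙ (B x_t)`. The complex diagonal `λ = exp(-exp(ν) + i·exp(θ))` has modulus below 1 by construction, and `γ = sqrt(1 - |λ|²)` normalizes the input. The NumPy sketch below follows the paper's ring initialization and runs the recurrence sequentially. The function names and default radii are illustrative assumptions, not `LRUCell`'s API.

```python
import numpy as np

# Illustrative sketch only: init_lru / lru_scan are hypothetical, not Memorax API.

def init_lru(rng, d_in, d_hidden, r_min=0.9, r_max=0.999, max_phase=2 * np.pi):
    """Ring initialization: |lambda| is sampled in [r_min, r_max]."""
    u1, u2 = rng.uniform(size=d_hidden), rng.uniform(size=d_hidden)
    nu_log = np.log(-0.5 * np.log(u1 * (r_max**2 - r_min**2) + r_min**2))
    theta_log = np.log(max_phase * u2)
    lam = np.exp(-np.exp(nu_log) + 1j * np.exp(theta_log))  # complex diagonal recurrence
    gamma = np.sqrt(1.0 - np.abs(lam) ** 2)                 # input normalization
    B = (rng.normal(size=(d_hidden, d_in)) + 1j * rng.normal(size=(d_hidden, d_in))) / np.sqrt(2 * d_in)
    C = (rng.normal(size=(d_in, d_hidden)) + 1j * rng.normal(size=(d_in, d_hidden))) / np.sqrt(d_hidden)
    D = rng.normal(size=d_in)
    return lam, gamma, B, C, D

def lru_scan(params, xs):
    lam, gamma, B, C, D = params
    h = np.zeros(lam.shape, dtype=complex)
    ys = []
    for x in xs:
        h = lam * h + gamma * (B @ x)    # element-wise (diagonal) linear recurrence
        ys.append((C @ h).real + D * x)  # real-valued readout plus skip connection
    return np.stack(ys)

rng = np.random.default_rng(0)
params = init_lru(rng, d_in=3, d_hidden=16)
ys = lru_scan(params, rng.normal(size=(20, 3)))
print(ys.shape)  # (20, 3)
```

Because the recurrence is linear and diagonal, an impulse at step 0 simply decays as `λ^t`. The same property lets the scan be parallelized.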
### S5
Simplified Structured State Space model: a multi-input, multi-output diagonal SSM evaluated with a parallel scan:
```python
from memorax.networks import Memoroid, S5Cell

s5 = Memoroid(cell=S5Cell(features=64, state_dim=64))
```
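S5 (Smith et al., 2023) starts from a continuous-time diagonal system `x'(t) = Λx(t) + Bu(t)`. It discretizes this with a learnable step size `Δ` and then runs the linear recurrence `x_k = Λ̄ x_{k-1} + B̄ u_k`. The NumPy sketch below shows the zero-order-hold (ZOH) discretization. It uses a real, negative spectrum for brevity, whereas S5 itself uses complex eigenvalues (the code also works for complex `Λ`). The function names are hypothetical.

```python
import numpy as np

# Illustrative sketch only: discretize_zoh / ssm_scan are hypothetical names.

def discretize_zoh(Lam, B, delta):
    """ZOH for a diagonal system: Lam_bar = exp(Lam*delta), B_bar = Lam^-1 (Lam_bar - 1) B."""
    Lam_bar = np.exp(Lam * delta)
    B_bar = ((Lam_bar - 1.0) / Lam)[:, None] * B
    return Lam_bar, B_bar

def ssm_scan(Lam_bar, B_bar, us):
    x = np.zeros(Lam_bar.shape, dtype=np.result_type(Lam_bar, B_bar))
    states = []
    for u in us:
        x = Lam_bar * x + B_bar @ u  # diagonal state transition, multi-channel input
        states.append(x)
    return np.stack(states)

rng = np.random.default_rng(0)
P, H = 4, 2                           # state size, input channels
Lam = -rng.uniform(0.5, 2.0, size=P)  # stable (negative) eigenvalues
B = rng.normal(size=(P, H))
Lam_bar, B_bar = discretize_zoh(Lam, B, delta=0.1)

# A constant input drives the discrete state to the continuous-time equilibrium -Lam^-1 B u.
u = np.array([1.0, -0.5])
states = ssm_scan(Lam_bar, B_bar, np.tile(u, (500, 1)))
print(np.allclose(states[-1], -(B @ u) / Lam))  # True
```

ZOH preserves the continuous system's fixed point exactly, and for small `Δ` it reduces to `B̄ ≈ ΔB`.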
### Mamba
Selective state space model, in which the state update depends on the current input:
```python
from memorax.networks import Memoroid, MambaCell

mamba = Memoroid(cell=MambaCell(features=64, num_heads=4, head_dim=16))
```
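Mamba (Gu & Dao, 2023) is "selective" because the step size `Δ_t` and the projections `B_t` and `C_t` are computed from the current input. Each step can therefore decide how much of its state to keep or overwrite. The NumPy sketch below is a simplified single-layer selective scan. It uses diagonal `A` and the Euler-style input term `Δ_t B_t x_t` from the reference implementation, and omits the convolution, gating, and fused kernel. All names are hypothetical, and this is not `MambaCell`'s implementation (which takes `num_heads`/`head_dim` and may follow a different variant).

```python
import numpy as np

# Illustrative sketch only: selective_scan and its weights are hypothetical.

def softplus(x):
    return np.logaddexp(0.0, x)

def selective_scan(xs, A, W_delta, b_delta, W_B, W_C):
    """xs: (T, D) inputs; A: (D, N) negative diagonal dynamics per channel.

    Delta_t, B_t, C_t are all functions of x_t, which is what makes the scan selective.
    """
    T, D = xs.shape
    h = np.zeros(A.shape)                        # one N-dim state per channel
    ys = np.zeros((T, D))
    for t, x in enumerate(xs):
        delta = softplus(W_delta @ x + b_delta)  # (D,) input-dependent step size
        B_t, C_t = W_B @ x, W_C @ x              # (N,) input-dependent projections
        A_bar = np.exp(delta[:, None] * A)       # (D, N) per-step decay in (0, 1)
        h = A_bar * h + (delta[:, None] * B_t[None, :]) * x[:, None]
        ys[t] = h @ C_t                          # (D,) readout
    return ys

rng = np.random.default_rng(0)
T, D, N = 12, 4, 8
A = -np.tile(np.arange(1, N + 1, dtype=float), (D, 1))  # S4D-real style init: A_n = -n
params = dict(
    W_delta=rng.normal(scale=0.1, size=(D, D)), b_delta=np.zeros(D),
    W_B=rng.normal(size=(N, D)), W_C=rng.normal(size=(N, D)),
)
ys = selective_scan(rng.normal(size=(T, D)), A, **params)
print(ys.shape)  # (12, 4)
```

A large `Δ_t` drives `A_bar` toward 0, so the state is replaced by the current input. A tiny `Δ_t` keeps the state and effectively ignores the input.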
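### MinGRU
Minimal GRU variant whose gates depend only on the current input, not the previous hidden state, which allows parallel computation over time:
```python
from memorax.networks import Memoroid, MinGRUCell

mingru = Memoroid(cell=MinGRUCell(features=64))
```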
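In minGRU (Feng et al., 2024) the update is `z_t = σ(W_z x_t)`, `h̃_t = W_h x_t`, `h_t = (1 - z_t) ⊙ h_{t-1} + z_t ⊙ h̃_t`. Nothing inside the gates depends on `h_{t-1}`, so all gates can be computed for the whole sequence at once. The recurrence then unrolls in closed form with cumulative products. The NumPy sketch below compares the two forms. The weight names are hypothetical, biases are omitted, and the short example skips the log-space scan the paper uses for numerical stability on long sequences.

```python
import numpy as np

# Illustrative sketch only: W_z / W_h and both helpers are hypothetical.

def sigmoid(x):
    return 0.5 * (1.0 + np.tanh(0.5 * x))

def mingru_sequential(xs, W_z, W_h):
    h, out = np.zeros(W_h.shape[0]), []
    for x in xs:
        z = sigmoid(W_z @ x)               # update gate depends on x_t only
        h = (1.0 - z) * h + z * (W_h @ x)  # candidate is a linear map of x_t
        out.append(h)
    return np.stack(out)

def mingru_parallel(xs, W_z, W_h):
    """Closed-form unroll: h_t = P_t * sum_{k<=t} b_k / P_k, with P_t = prod_{j<=t} a_j."""
    z = sigmoid(xs @ W_z.T)  # (T, d): all gates at once
    a = 1.0 - z              # per-step retention
    b = z * (xs @ W_h.T)     # per-step input contribution
    P = np.cumprod(a, axis=0)
    return P * np.cumsum(b / P, axis=0)

rng = np.random.default_rng(0)
T, in_dim, d = 8, 3, 5
xs = rng.normal(size=(T, in_dim))
W_z, W_h = rng.normal(scale=0.5, size=(d, in_dim)), rng.normal(size=(d, in_dim))
print(np.allclose(mingru_sequential(xs, W_z, W_h), mingru_parallel(xs, W_z, W_h)))  # True
```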
### mLSTM (Matrix LSTM)
```python
from memorax.networks import Memoroid, mLSTMCell

mlstm = Memoroid(cell=mLSTMCell(features=64, num_heads=4, head_dim=16))
```
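The mLSTM (also from the xLSTM paper) stores a matrix memory `C_t = f_t C_{t-1} + i_t v_t k_tᵀ` and a normalizer `n_t = f_t n_{t-1} + i_t k_t`. It reads them with a query: `h̃_t = C_t q_t / max(|n_tᵀ q_t|, 1)`. The NumPy sketch below is single-headed and passes the scalar gates in directly, omitting the paper's gate stabilizer and output gate. Its helper names are hypothetical. With orthogonal unit keys, each stored value can be read back exactly.

```python
import numpy as np

# Illustrative sketch only: mlstm_write / mlstm_read are hypothetical helpers.

def mlstm_write(C, n, k, v, i_gate, f_gate):
    """Decay the old memory by f_gate and add the outer product v k^T weighted by i_gate."""
    C = f_gate * C + i_gate * np.outer(v, k)
    n = f_gate * n + i_gate * k
    return C, n

def mlstm_read(C, n, q):
    """Query the matrix memory; the normalizer keeps the readout scale bounded."""
    return (C @ q) / max(abs(n @ q), 1.0)

d = 4
C, n = np.zeros((d, d)), np.zeros(d)
k1, v1 = np.eye(d)[0], np.array([1.0, 2.0, 3.0, 4.0])  # orthogonal unit keys
k2, v2 = np.eye(d)[1], np.array([-1.0, 0.0, 0.5, 0.0])
C, n = mlstm_write(C, n, k1, v1, i_gate=1.0, f_gate=1.0)
C, n = mlstm_write(C, n, k2, v2, i_gate=1.0, f_gate=1.0)
print(mlstm_read(C, n, k1))  # [1. 2. 3. 4.]
```

Setting `f_gate < 1` on a later write scales down everything stored earlier, which is how the memory forgets.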
### FFM (Fast and Forgetful Memory)
```python
from memorax.networks import Memoroid, FFMCell

ffm = Memoroid(cell=FFMCell(features=64, memory_size=32))
```
## Attention
### Self-Attention
Standard multi-head attention (used directly, no wrapper needed):
```python
from memorax.networks import SelfAttention

attention = SelfAttention(
    features=64,
    num_heads=4,
    head_dim=16,
)
```
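For reference, the NumPy sketch below shows the computation behind multi-head self-attention over a trajectory. It applies a causal mask so that each step attends only to itself and earlier steps. Whether `SelfAttention` masks causally by default is an assumption not confirmed here. The sizes mirror the configuration above, the projection matrices are hypothetical, and no positional embeddings are included.

```python
import numpy as np

# Illustrative sketch only: weights and function name are hypothetical, not Memorax API.

def causal_self_attention(x, W_q, W_k, W_v, W_o, num_heads, head_dim):
    """x: (T, features). Returns (T, features); position t attends only to positions <= t."""
    T = x.shape[0]
    # Project and split into heads: each of q, k, v has shape (num_heads, T, head_dim)
    q, k, v = (
        (x @ W).reshape(T, num_heads, head_dim).transpose(1, 0, 2) for W in (W_q, W_k, W_v)
    )
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)  # (num_heads, T, T)
    future = np.triu(np.ones((T, T), dtype=bool), k=1)      # True strictly above the diagonal
    scores = np.where(future, -np.inf, scores)
    scores -= scores.max(axis=-1, keepdims=True)            # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    out = weights @ v                                       # (num_heads, T, head_dim)
    return out.transpose(1, 0, 2).reshape(T, num_heads * head_dim) @ W_o

rng = np.random.default_rng(0)
features, num_heads, head_dim, T = 64, 4, 16, 10
W_q, W_k, W_v = (rng.normal(scale=features**-0.5, size=(features, num_heads * head_dim)) for _ in range(3))
W_o = rng.normal(scale=(num_heads * head_dim) ** -0.5, size=(num_heads * head_dim, features))
x = rng.normal(size=(T, features))
y = causal_self_attention(x, W_q, W_k, W_v, W_o, num_heads, head_dim)
print(y.shape)  # (10, 64)
```

The score matrix is `T x T` per head, so compute and memory grow quadratically with the number of attended steps.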
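### Linear Attention
Attention that replaces the softmax with a kernel (feature-map) similarity, so its cost is linear in sequence length and it can run as a recurrent cell inside `Memoroid`:
```python
from memorax.networks import Memoroid, LinearAttentionCell

linear_attention = Memoroid(
    cell=LinearAttentionCell(features=64, num_heads=4, head_dim=16)
)
```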
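Linear attention (Katharopoulos et al., 2020) replaces `softmax(q·k)` with `φ(q)·φ(k)` for a positive feature map `φ`. Causal attention can then be carried as a running state `S_t = S_{t-1} + φ(k_t) v_tᵀ` with normalizer `z_t = z_{t-1} + φ(k_t)`. The NumPy sketch below is single-headed and uses `φ(x) = elu(x) + 1` as in that paper. It checks the O(T) recurrent form against the explicit O(T²) causal form. It is not `LinearAttentionCell`'s implementation, whose feature map and normalization may differ, and the helper names are hypothetical.

```python
import numpy as np

# Illustrative sketch only: helper names are hypothetical.

def phi(x):
    """elu(x) + 1: a strictly positive feature map."""
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))

def linear_attention_recurrent(q, k, v):
    """O(T) in sequence length: carry S (d_k x d_v) and z (d_k) across steps."""
    S = np.zeros((k.shape[1], v.shape[1]))
    z = np.zeros(k.shape[1])
    out = []
    for q_t, k_t, v_t in zip(phi(q), phi(k), v):
        S += np.outer(k_t, v_t)
        z += k_t
        out.append(q_t @ S / (q_t @ z))
    return np.stack(out)

def linear_attention_causal(q, k, v):
    """O(T^2) reference: explicit causally masked, normalized similarity matrix."""
    A = np.tril(phi(q) @ phi(k).T)  # A[t, s] = phi(q_t) . phi(k_s) for s <= t
    return (A @ v) / A.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
T, d_k, d_v = 7, 4, 3
q, k, v = rng.normal(size=(T, d_k)), rng.normal(size=(T, d_k)), rng.normal(size=(T, d_v))
print(np.allclose(linear_attention_recurrent(q, k, v), linear_attention_causal(q, k, v)))  # True
```

The recurrent state has a fixed size, independent of `T`. That is what makes linear attention usable step by step during rollouts.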
## Choosing a Model
### For Short Episodes (< 100 steps)
- LSTM/GRU: simple and effective
- sLSTM: exponential gating with a normalizer state
### For Long Episodes (100-1000 steps)
- S5/LRU: efficient state space models
- Mamba: input-dependent (selective) state updates
### For Very Long Episodes (> 1000 steps)
- SelfAttention: with positional embeddings
- LinearAttention: cost linear in sequence length
### For Memory-Intensive Tasks
- FFM/SHM: explicit memory mechanisms
- mLSTM: matrix-valued memory
## Example: Mamba Agent
```python
from memorax.algorithms import PPO, PPOConfig
from memorax.networks import Network, FeatureExtractor, MLP, Memoroid, MambaCell, heads

# Assumes `cfg` (a PPOConfig), `env`, `env_params`, and `optimizer` are already defined.

# Mamba-based actor; set action_dim to the environment's number of discrete actions
actor_network = Network(
    feature_extractor=FeatureExtractor(observation_extractor=MLP(features=(64,))),
    torso=Memoroid(cell=MambaCell(features=64, num_heads=4, head_dim=16)),
    head=heads.Categorical(action_dim=4),
)

# Mamba-based critic with a state-value head
critic_network = Network(
    feature_extractor=FeatureExtractor(observation_extractor=MLP(features=(64,))),
    torso=Memoroid(cell=MambaCell(features=64, num_heads=4, head_dim=16)),
    head=heads.VNetwork(),
)

# `optimizer` is passed twice, presumably once for the actor and once for the critic
agent = PPO(cfg, env, env_params, actor_network, critic_network, optimizer, optimizer)
```