Most JAX RL libraries treat memory as an afterthought, bolting an LSTM onto an existing agent and calling it done. Memorax makes memory a first-class citizen. It provides a composable set of sequence model primitives (attention, SSMs, linear RNNs, and more) that snap together into full architectures like GTrXL or xLSTM, paired with algorithms and replay buffers designed from the ground up for recurrent training.
Features
| | Details |
|---|---|
| Algorithms | |
| Sequence Models | LSTM, GRU, xLSTM, FFM, SHM, S5, LRU, Mamba, MinGRU, RTU, Self-Attention, Linear Attention. Compose into GTrXL, GPT-2, and more. Support for RTRL |
| Networks | ViT encoder. RoPE and ALiBi positional embeddings. MoE for horizontal scaling. RL² wrapper for meta-RL. GVF/Horde heads. C51 and HL-Gauss distributional value heads. Composable |
| Environments | Gymnax, PopJym, PopGym Arcade, Navix, Craftax, Brax, MuJoCo, gxm, Grimax, POBAX, XMiniGrid, JaxMARL |
| Buffers | Pure JAX episode replay with prioritized sampling via Flashbax |
| Logging | CLI Dashboard, File, W&B, TensorboardX, Neptune |
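To make one of the listed building blocks concrete, here is a minimal sketch of the ALiBi attention bias in plain Python, so it runs without any dependencies. It is not Memorax's implementation: the function names `alibi_slopes` and `alibi_bias` are hypothetical, the slope formula is the geometric schedule from the ALiBi paper for power-of-two head counts, and folding the causal mask into the bias is an assumption made to keep the example short.

```python
def alibi_slopes(num_heads: int) -> list[float]:
    """Geometric per-head slopes; this sketch only handles power-of-two head counts."""
    if num_heads <= 0 or num_heads & (num_heads - 1):
        raise ValueError("num_heads must be a positive power of two in this sketch")
    # Head h (1-indexed) gets slope 2 ** (-8 * h / num_heads).
    return [2.0 ** (-8.0 * (h + 1) / num_heads) for h in range(num_heads)]


def alibi_bias(num_heads: int, length: int) -> list[list[list[float]]]:
    """Return bias[h][i][j] = -slope_h * (i - j) for keys j <= i, and -inf for future keys."""
    return [
        [
            [-slope * (i - j) if j <= i else float("-inf") for j in range(length)]
            for i in range(length)
        ]
        for slope in alibi_slopes(num_heads)
    ]


# Hypothetical usage: 4 heads over a 3-step sequence.
bias = alibi_bias(num_heads=4, length=3)
```

A bias of this shape is added to the attention logits before the softmax, so more distant keys are penalized linearly, with a different strength per head, and no learned position vectors are needed.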
Installation
pip install memorax
With CUDA support:
pip install "memorax[cuda]"
See the Installation guide for more options.
Quick Start
import flax.linen as nn
import jax
import optax

from memorax.algorithms import PPO, PPOConfig
from memorax.environments import make
from memorax.networks import (
    FFN, ALiBi, FeatureExtractor, GatedResidual, Network,
    PreNorm, SegmentRecurrence, SelfAttention, SelfAttentionConfig, Stack, heads,
)

# Load CartPole from the Gymnax backend.
env, env_params = make("gymnax::CartPole-v1")

# PPO hyperparameters.
config = PPOConfig(
    num_envs=8,
    num_steps=128,
    gae_lambda=0.95,
    num_minibatches=4,
    update_epochs=4,
    normalize_advantage=True,
    clip_coefficient=0.2,
    clip_value_loss=True,
    entropy_coefficient=0.01,
)

features, num_heads, num_layers = 64, 4, 2

# Embed observations with a single dense layer followed by ReLU.
feature_extractor = FeatureExtractor(
    observation_extractor=nn.Sequential((nn.Dense(features), nn.relu))
)

# Attention block: segment-recurrent self-attention with ALiBi,
# wrapped in pre-normalization and a gated residual connection.
attention = GatedResidual(PreNorm(SegmentRecurrence(
    SelfAttention(config=SelfAttentionConfig(
        features=features,
        num_heads=num_heads,
        context_length=128,
        positional_embedding=ALiBi(num_heads),
    )),
    memory_length=64,
    features=features,
)))
# Feed-forward block with the same pre-norm, gated-residual wrapping.
ffn = GatedResidual(PreNorm(FFN(features=features, expansion_factor=4)))
# Alternate attention and feed-forward blocks, num_layers times.
torso = Stack(blocks=(attention, ffn) * num_layers)

# Actor and critic use the same extractor and torso definitions with different heads.
actor = Network(feature_extractor, torso, heads.Categorical(env.action_space(env_params).n))
critic = Network(feature_extractor, torso, heads.VNetwork())

# Clip gradients by global norm, then apply Adam; the same optimizer is passed twice.
optimizer = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(3e-4))
agent = PPO(config, env, env_params, actor, critic, optimizer, optimizer)

# Initialize the agent state, then train.
key = jax.random.key(0)
key, init_key = jax.random.split(key)
state = agent.init(init_key)
key, train_key = jax.random.split(key)
state = agent.train(train_key, state, num_steps=10_000)
See the Quick Start for a complete walkthrough.
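For readers unfamiliar with the `gae_lambda` setting in the config above, the following is a minimal plain-Python sketch of Generalized Advantage Estimation, the quantity that parameter usually controls in PPO. It is an illustration under stated assumptions, not Memorax's code: the function name `gae` is hypothetical, the discount `gamma=0.99` is an assumed default that does not appear in the config shown, and `dones[t]` is assumed to mean that the episode ended after step `t`.

```python
def gae(rewards, values, dones, last_value, gamma=0.99, gae_lambda=0.95):
    """Compute advantages for one trajectory by scanning backwards in time."""
    advantages = [0.0] * len(rewards)
    next_value = last_value  # bootstrap value for the state after the final step
    running = 0.0            # discounted sum of future TD errors
    for t in reversed(range(len(rewards))):
        # Stop bootstrapping and accumulation across an episode boundary.
        not_done = 1.0 - float(dones[t])
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        running = delta + gamma * gae_lambda * not_done * running
        advantages[t] = running
        next_value = values[t]
    return advantages


# Hypothetical 3-step rollout that terminates on the last step.
rewards = [1.0, 1.0, 1.0]
values = [0.5, 0.5, 0.5]
dones = [False, False, True]
advantages = gae(rewards, values, dones, last_value=0.0)
# Value targets are commonly formed as advantage plus the value estimate.
returns = [a + v for a, v in zip(advantages, values)]
```

With `gae_lambda=1.0` the estimate reduces to the discounted return minus the value baseline, while `gae_lambda=0.0` keeps only the one-step TD error; values such as 0.95 trade some bias for lower variance.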
Citation
@software{memorax2025github,
  title = {Memorax: A Unified Framework for Memory-Augmented Reinforcement Learning},
  author = {Noah Farr},
  year = {2025},
  url = {https://github.com/noahfarr/memorax}
}