Memory-Augmented Reinforcement Learning in JAX

A unified JAX/Flax framework for memory-augmented reinforcement learning.

Memorax provides modular, high-performance implementations of RL algorithms that support memory-based sequence models, including recurrent neural networks (RNNs), state space models (SSMs), and Transformers.
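In memory-augmented RL, the policy carries a hidden state across timesteps instead of acting on the current observation alone. The following is a minimal sketch of that idea in plain JAX, not the Memorax API. The helpers `init_params`, `rnn_step`, and `unroll` are hypothetical. They implement a vanilla Elman RNN, which is simpler than the LSTM, GRU, or SSM cells listed below. The sketch unrolls the cell over a trajectory with `jax.lax.scan` and resets the hidden state at episode boundaries, assuming `done_seq[t]` marks that `obs_seq[t]` is the first observation of a new episode.

```python
import jax
import jax.numpy as jnp


def init_params(key, obs_dim, hidden_dim, num_actions):
    """Randomly initialise a single-layer Elman RNN policy (hypothetical helper)."""
    k_x, k_h, k_out = jax.random.split(key, 3)
    scale = 0.1
    return {
        "W_x": scale * jax.random.normal(k_x, (obs_dim, hidden_dim)),
        "W_h": scale * jax.random.normal(k_h, (hidden_dim, hidden_dim)),
        "b_h": jnp.zeros(hidden_dim),
        "W_out": scale * jax.random.normal(k_out, (hidden_dim, num_actions)),
    }


def rnn_step(params, h, inputs):
    """One timestep: optionally reset memory, update it, and emit action logits."""
    obs, done = inputs
    # Zero the hidden state when obs starts a new episode, so memory
    # does not leak across episode boundaries.
    h = jnp.where(done, jnp.zeros_like(h), h)
    h = jnp.tanh(obs @ params["W_x"] + h @ params["W_h"] + params["b_h"])
    logits = h @ params["W_out"]
    return h, logits


def unroll(params, h0, obs_seq, done_seq):
    """Scan rnn_step over a (T, obs_dim) trajectory; returns final state and (T, A) logits."""
    return jax.lax.scan(lambda h, x: rnn_step(params, h, x), h0, (obs_seq, done_seq))


# Usage: 5 timesteps with an episode boundary at t = 2.
params = init_params(jax.random.PRNGKey(0), obs_dim=4, hidden_dim=8, num_actions=2)
obs_seq = jnp.ones((5, 4))
done_seq = jnp.array([False, False, True, False, False])
h_final, logits_seq = unroll(params, jnp.zeros(8), obs_seq, done_seq)
```

Replacing `rnn_step` with a GRU, LSTM, or linear recurrence changes only the cell. The scan-with-reset structure stays the same.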

Features

  • 🧠 Algorithms: PPO, DQN, SAC, and PQN, fully vectorized in JAX

  • 🔁 Sequence Models: LSTM, GRU, Mamba, S5, LRU, Linear Attention, and more

  • 🌍 Environments: Integration with Gymnax, Brax, POPGym, Craftax, and others

  • 📊 Logging: Weights & Biases, TensorBoard, Neptune, and console logging
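A common way to get vectorized RL in JAX is to write the logic for a single environment and let `jax.vmap` batch it across many environments, with `jax.jit` compiling the result. The sketch below shows that pattern under stated assumptions: `env_step` is a hypothetical 1-D toy environment, the policy is uniformly random, and no Gymnax, Brax, or Memorax APIs are used.

```python
from functools import partial

import jax
import jax.numpy as jnp

NUM_STEPS = 10


def env_step(pos, action):
    """Hypothetical 1-D toy environment: action 1 moves right, action 0 moves left."""
    new_pos = pos + jnp.where(action == 1, 1.0, -1.0)
    reward = -jnp.abs(new_pos)  # reward for staying near the origin
    return new_pos, reward


def rollout(key, num_steps):
    """Roll out a uniformly random policy in ONE environment with lax.scan."""
    step_keys = jax.random.split(key, num_steps)

    def step(pos, k):
        action = jax.random.bernoulli(k).astype(jnp.int32)
        return env_step(pos, action)  # (new carry, per-step reward)

    return jax.lax.scan(step, jnp.zeros((), dtype=jnp.float32), step_keys)


# vmap maps the single-environment rollout over a leading batch axis of keys;
# num_steps is bound with partial so it stays a static Python int.
batched_rollout = jax.jit(jax.vmap(partial(rollout, num_steps=NUM_STEPS)))

keys = jax.random.split(jax.random.PRNGKey(0), 4)  # one PRNG key per environment
final_pos, rewards = batched_rollout(keys)  # shapes (4,) and (4, NUM_STEPS)
```

Because `rollout` is written for one environment, changing the batch size only means passing a different number of keys.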