Core Concepts#

Architecture Overview#

Memorax follows a modular architecture where agents are composed of an algorithm, networks, and an environment.

Agent
Algorithm
PPO, DQN, SAC, R2D2, ...
Config State
Network
Feature Extractor ↓ Torso (Sequence Model) ↓ Head
Environment
Gymnax, Brax, Navix, ...
env env_params

Algorithms#

Every algorithm in Memorax follows the same pattern: a Config (frozen dataclass of hyperparameters), a State (all mutable training state: parameters, optimizer state, hidden carries), and the Algorithm class itself which provides init(), train(), warmup(), and evaluate(). You create an agent by passing the config, environment, networks, and optimizers, then call init to get the initial state and train to run updates:

from memorax.algorithms import PPO, PPOConfig

config = PPOConfig(
    num_envs=8, num_steps=128, gae_lambda=0.95, num_minibatches=4,
    update_epochs=4, normalize_advantage=True, clip_coefficient=0.2,
    clip_value_loss=True, entropy_coefficient=0.01,
)

agent = PPO(config, env, env_params, actor, critic, optimizer, optimizer)
key, init_key = jax.random.split(key)
state = agent.init(init_key)
key, train_key = jax.random.split(key)
state = agent.train(train_key, state, num_steps=1000)

All algorithms share this interface, so switching from PPO to DQN or SAC is just a matter of changing the config and providing the right networks.

Networks#

Networks follow the FeatureExtractor → Torso → Head pipeline shown above. The feature extractor processes raw observations into a feature vector. Any Flax module works here. The torso handles temporal modeling: RNNs and Memoroid cells work directly as sequence models, while non-recurrent modules need a SequenceModelWrapper. The head produces the algorithm’s output: action distributions for policy methods, Q-values for value methods, or state values for critics:

import flax.linen as nn
from memorax.networks import FeatureExtractor, Network, RNN, Memoroid, Mamba2Cell, Mamba2Config, heads

feature_extractor = FeatureExtractor(
    observation_extractor=nn.Sequential((nn.Dense(64), nn.relu)),
)

actor = Network(feature_extractor, torso=RNN(cell=nn.GRUCell(features=64)), head=heads.Categorical(action_dim=4))
critic = Network(feature_extractor, torso=Memoroid(cell=Mamba2Cell(config=Mamba2Config(features=64))), head=heads.VNetwork())

See the Building Networks guide for the full list of available torsos, heads, and composable blocks.

JAX Patterns#

Memorax is built on JAX, which brings a few patterns that differ from typical PyTorch workflows. Vectorized environments: all training runs num_envs environments in parallel, fully JIT-compiled. Explicit random state: JAX uses functional random keys. Callers split and pass keys to each method call; methods consume their key without returning one. JIT compilation: the first call to train may be slow as JAX traces and compiles the computation graph, but subsequent calls are fast.

key = jax.random.key(0)
key, init_key = jax.random.split(key)
state = agent.init(init_key)
key, train_key = jax.random.split(key)
state = agent.train(train_key, state, num_steps=1000)

Logging#

Memorax doesn’t return metrics from train() directly. Instead, wrap your training function with lox.spool to capture logged values, and wrap the environment with RecordEpisodeStatistics to track episode returns and lengths. Use jax.vmap to vectorize across random seeds:

import lox
from memorax.environments.wrappers import RecordEpisodeStatistics

env = RecordEpisodeStatistics(env)
train = jax.vmap(lox.spool(agent.train), in_axes=(0, 0, None))
key, train_key = jax.random.split(key)
state, logs = train(jax.random.split(train_key, num_seeds), state, num_steps)

Memorax provides several logger backends: DashboardLogger for a live terminal UI, WandbLogger, TensorBoardLogger, FileLogger, and CheckpointLogger. Combine multiple backends with the MultiLogger wrapper.

Buffers#

Off-policy algorithms like DQN and R2D2 need replay buffers. Memorax provides episode-aware buffers built on Flashbax that sample complete sequences while respecting episode boundaries. The basic make_episode_buffer handles uniform sampling, while make_prioritised_episode_buffer adds TD-error-weighted sampling with importance correction. This is what R2D2 uses for prioritized experience replay:

from memorax.buffers import make_episode_buffer, make_prioritised_episode_buffer

buffer = make_episode_buffer(
    max_length=100_000, min_length=1000,
    sample_batch_size=32, sample_sequence_length=16, add_batch_size=8,
)

prioritised_buffer = make_prioritised_episode_buffer(
    max_length=100_000, min_length=1000,
    sample_batch_size=32, sample_sequence_length=16, add_batch_size=8,
    priority_exponent=0.6,
)