Core Concepts#
This page explains the key concepts and design patterns in Memorax.
Architecture Overview#
Memorax follows a modular architecture where agents are composed of:
Agent = Algorithm + Network + Environment
Network = FeatureExtractor -> Torso -> Head
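As a rough illustration of the data flow only, the network pipeline above can be read as function composition. The three toy functions below are hypothetical stand-ins, not Memorax classes, and the sketch ignores the recurrent state that a real torso carries between steps.

```python
from typing import Callable, List

Vector = List[float]

def compose(feature_extractor: Callable, torso: Callable, head: Callable) -> Callable:
    """Chain the three stages: observation -> features -> hidden -> output."""
    def network(observation):
        features = feature_extractor(observation)
        hidden = torso(features)
        return head(hidden)
    return network

# Hypothetical stand-ins, for illustration only.
def toy_feature_extractor(observation: Vector) -> Vector:
    return [2.0 * x for x in observation]   # scale each input

def toy_torso(features: Vector) -> Vector:
    return [max(0.0, x) for x in features]  # ReLU-like nonlinearity

def toy_head(hidden: Vector) -> float:
    return sum(hidden)                      # reduce to a scalar output

network = compose(toy_feature_extractor, toy_torso, toy_head)
output = network([1.0, -3.0, 0.5])
```

Because each stage only depends on the output of the previous one, any stage can be swapped without touching the others, which is the point of the modular layout.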
Algorithms#
Each algorithm in Memorax consists of three components:
Config: A frozen dataclass containing hyperparameters
State: A dataclass holding the training state (parameters, optimizer state, etc.)
Algorithm: The main class implementing init(), train(), warmup(), and evaluate()
import jax
from memorax.algorithms import PPO, PPOConfig, PPOState
# Configure
cfg = PPOConfig(num_envs=8, gamma=0.99)
# Create algorithm (env, env_params, actor, critic, and optimizer must already be constructed)
agent = PPO(cfg, env, env_params, actor, critic, optimizer, optimizer)
# Initialize state from a JAX PRNG key
key = jax.random.key(0)
key, state = agent.init(key)
# Train
key, state, transitions = agent.train(key, state, num_steps=1000)
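The Config/State split described above can be sketched with plain dataclasses. This is a minimal sketch under stated assumptions: `ToyConfig`, `ToyState`, and `toy_train` are hypothetical names, and the step bookkeeping is invented for illustration rather than taken from Memorax.

```python
from dataclasses import dataclass, replace, FrozenInstanceError

@dataclass(frozen=True)
class ToyConfig:
    """Hypothetical hyperparameter container; frozen so it cannot change mid-run."""
    num_envs: int = 8
    gamma: float = 0.99

@dataclass
class ToyState:
    """Hypothetical training state; a new instance is returned after each update."""
    step: int
    params: tuple

def toy_train(cfg: ToyConfig, state: ToyState, num_steps: int) -> ToyState:
    # Return a new state instead of mutating the old one, in the functional
    # style that JAX code generally uses. The step arithmetic is toy bookkeeping.
    return replace(state, step=state.step + num_steps * cfg.num_envs)

cfg = ToyConfig(num_envs=8, gamma=0.99)
state = toy_train(cfg, ToyState(step=0, params=(0.0,)), num_steps=1000)

try:
    cfg.gamma = 0.9  # frozen dataclasses reject attribute assignment
    mutated = True
except FrozenInstanceError:
    mutated = False
```

Keeping hyperparameters immutable and passing state in and out explicitly makes a training call easy to reason about: the same config and state always describe the same starting point.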
Networks#
Networks in Memorax are composed of three parts:
Feature Extractor#
Extracts features from observations, actions, rewards, and done flags:
from memorax.networks import FeatureExtractor, MLP
feature_extractor = FeatureExtractor(
observation_extractor=MLP(features=(64, 64)),
action_extractor=None, # Optional
reward_extractor=None, # Optional
done_extractor=None, # Optional
)
Torso (Sequence Model)#
Processes temporal sequences using RNNs, SSMs, or attention:
import flax.linen as nn
from memorax.networks import SequenceModelWrapper, RNN, MLP, Memoroid, MambaCell
# GRU torso (recurrent models don't need a wrapper)
torso = RNN(cell=nn.GRUCell(features=64))
# Mamba torso (MambaCell is used inside Memoroid)
torso = Memoroid(cell=MambaCell(features=64))
# MLP torso (non-recurrent models need SequenceModelWrapper)
torso = SequenceModelWrapper(MLP(features=(64,)))
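One plausible reading of why non-recurrent models need a wrapper is that a sequence model steps with a `(carry, input) -> (carry, output)` signature, which a stateless function does not have. The sketch below shows that idea with `jax.lax.scan`. Both `recurrent_step` and `wrap_stateless` are hypothetical helpers, and the internals of `SequenceModelWrapper` are not shown on this page.

```python
import jax
import jax.numpy as jnp

def recurrent_step(carry, x):
    """Toy recurrent cell: an exponential moving average of the inputs."""
    carry = 0.5 * carry + 0.5 * x
    return carry, carry  # (next carry, per-step output)

def wrap_stateless(fn):
    """Adapt a stateless function to the (carry, x) -> (carry, y) step signature."""
    def step(carry, x):
        return carry, fn(x)  # the carry passes through unchanged
    return step

xs = jnp.array([1.0, 1.0, 1.0])
init_carry = jnp.zeros(())

# The recurrent cell accumulates information across timesteps.
_, recurrent_out = jax.lax.scan(recurrent_step, init_carry, xs)
# The wrapped stateless function treats each timestep independently.
_, stateless_out = jax.lax.scan(wrap_stateless(lambda x: 2.0 * x), init_carry, xs)
```

With the adapter in place, both kinds of model can be driven by the same scanning loop, so the rest of the pipeline does not need to know which one it is using.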
Head#
Produces outputs for the RL objective:
from memorax.networks import heads
# Discrete policy
head = heads.Categorical(action_dim=4)
# Continuous policy
head = heads.SquashedGaussian(action_dim=2)
# Value function
head = heads.VNetwork()
# Q-network
head = heads.DiscreteQNetwork(action_dim=4)
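To make the two policy heads more concrete, the sketch below samples an action from each kind of distribution using plain JAX. It assumes that a squashed Gaussian means a Gaussian sample passed through `tanh`, as is common in continuous-control methods. This page does not confirm how `heads.SquashedGaussian` is implemented, and the logits, means, and log standard deviations are made-up values.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
key_discrete, key_continuous = jax.random.split(key)

# Discrete policy: logits over 4 actions -> one sampled action index.
logits = jnp.array([0.1, 2.0, -1.0, 0.5])
discrete_action = jax.random.categorical(key_discrete, logits)

# Continuous policy: sample a Gaussian, then squash with tanh so every
# action component lies within [-1, 1].
mean = jnp.array([0.0, 3.0])
log_std = jnp.array([0.0, -1.0])
noise = jax.random.normal(key_continuous, mean.shape)
continuous_action = jnp.tanh(mean + jnp.exp(log_std) * noise)
```

The squashing step is what keeps continuous actions bounded regardless of how large the underlying Gaussian sample is.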
JAX Patterns#
Memorax leverages JAX for performance:
Vectorized Environments#
All training runs multiple environments in parallel:
cfg = PPOConfig(num_envs=8) # 8 parallel environments
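In JAX, running several environments in parallel is typically expressed by writing the step for one environment and mapping it over a batch with `jax.vmap`. The sketch below uses a hypothetical `toy_env_step`. It is not Memorax's environment interface, only an illustration of the vectorization pattern.

```python
import jax
import jax.numpy as jnp

def toy_env_step(state, action):
    """Hypothetical single-environment step: move the state, reward closeness to 0."""
    next_state = state + action
    reward = -jnp.abs(next_state)
    return next_state, reward

num_envs = 8
states = jnp.zeros(num_envs)                 # one scalar state per environment
actions = jnp.linspace(-1.0, 1.0, num_envs)  # one action per environment

# vmap maps the single-environment function over the leading (environment) axis.
batched_step = jax.vmap(toy_env_step)
next_states, rewards = batched_step(states, actions)
```

The step function itself never mentions the batch dimension, so the same code works for 1 environment or 8.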
Random Keys#
JAX uses explicit random state management:
import jax
key = jax.random.key(0)
key, state = agent.init(key)
key, state, transitions = agent.train(key, state, num_steps=1000)
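The reason each call above returns a fresh `key` is that JAX random functions are deterministic in their key: reusing a key reproduces the same numbers. A minimal sketch using only `jax.random`, independent of Memorax:

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)

# Split off a subkey for each use and keep the returned key for later splits.
key, subkey_a = jax.random.split(key)
key, subkey_b = jax.random.split(key)

sample_a = jax.random.normal(subkey_a, (3,))
sample_a_again = jax.random.normal(subkey_a, (3,))  # same key -> same values
sample_b = jax.random.normal(subkey_b, (3,))        # different key -> new values
```

Threading the key through `init()` and `train()` follows the same idea: every consumer gets its own randomness, and a run is reproducible from the initial seed.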
JIT Compilation#
Training loops are JIT-compiled for performance. The first call is slower because JAX traces and compiles the function; later calls with the same input shapes and dtypes reuse the compiled code.
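The compile-once behavior can be observed with a small, Memorax-independent sketch. Python side effects inside a jitted function run only while JAX traces it, so recording them shows when a trace (and hence a compilation) happens. The `update` function here is a hypothetical stand-in for a training step.

```python
import jax
import jax.numpy as jnp

trace_log = []

@jax.jit
def update(x):
    # Python side effect: executes only while JAX traces the function,
    # not on later calls that reuse the compiled code.
    trace_log.append(x.shape)
    return x * 2.0 + 1.0

first = update(jnp.ones(4))   # traces and compiles for shape (4,)
second = update(jnp.ones(4))  # same shape and dtype: reuses the compiled code
third = update(jnp.ones(5))   # new shape: triggers another trace
```

In practice this means the cost of the first training call should not be taken as representative of steady-state throughput.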
Transitions#
Training produces Transition objects containing:
observation: Environment observations
action: Actions taken
reward: Rewards received
done: Episode termination flags
info: Additional information (returns, etc.)
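As a sketch of how such a record might be consumed, the example below defines a hypothetical `ToyTransition` with the same five fields and computes discounted returns that restart at episode boundaries. The container type, the list-based fields, and the convention that `done[t]` marks step `t` as terminal are all assumptions for illustration, not Memorax's actual `Transition` class.

```python
from typing import Any, Dict, List, NamedTuple

class ToyTransition(NamedTuple):
    """Hypothetical container mirroring the fields listed above."""
    observation: List[float]
    action: List[int]
    reward: List[float]
    done: List[bool]
    info: Dict[str, Any]

def discounted_returns(rewards, dones, gamma):
    """Compute per-step discounted returns, restarting at episode boundaries."""
    returns, running = [], 0.0
    for reward, done in zip(reversed(rewards), reversed(dones)):
        if done:
            running = 0.0  # nothing after a terminal step belongs to this episode
        running = reward + gamma * running
        returns.append(running)
    return returns[::-1]

batch = ToyTransition(
    observation=[0.0, 0.1, 0.2, 0.3],
    action=[1, 0, 1, 1],
    reward=[1.0, 1.0, 1.0, 1.0],
    done=[False, True, False, False],
    info={},
)
returns = discounted_returns(batch.reward, batch.done, gamma=0.5)
```

The `done` flags are what prevent rewards from a later episode from leaking into the returns of an earlier one.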