# Working with Algorithms
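Memorax provides several reinforcement learning (RL) algorithms adapted for training memory-augmented agents, such as those built on recurrent networks (RNNs) or state-space models (SSMs).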
## Available Algorithms
| Algorithm | Action Space | Use Case |
|---|---|---|
| PPO | Discrete & Continuous | General-purpose, stable training |
| DQN | Discrete | Value-based learning |
| SAC | Continuous | Maximum entropy RL |
| PQN | Discrete | On-policy Q-learning |
## PPO (Proximal Policy Optimization)
Best for general-purpose training with memory architectures.
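```python
from memorax.algorithms import PPO, PPOConfig

cfg = PPOConfig(
    name="PPO-Experiment",
    num_envs=8,            # parallel training environments
    num_eval_envs=16,      # parallel evaluation environments
    num_steps=128,         # rollout length per environment
    gamma=0.99,            # discount factor
    gae_lambda=0.95,       # GAE lambda
    num_minibatches=4,     # minibatches per update epoch
    update_epochs=4,       # passes over each rollout
    normalize_advantage=True,
    clip_coef=0.2,         # policy ratio clipping coefficient
    clip_vloss=True,       # also clip the value loss
    ent_coef=0.01,         # entropy bonus weight
    vf_coef=0.5,           # value loss weight
    burn_in_length=0,      # RNN burn-in steps
)

# env, env_params, actor, critic and both optimizers must be created beforehand
agent = PPO(cfg, env, env_params, actor, critic, actor_optimizer, critic_optimizer)
```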
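The `gamma` and `gae_lambda` fields parameterize Generalized Advantage Estimation (GAE). The following is a minimal pure-Python sketch of the standard GAE recursion for a single environment's rollout, assuming Memorax uses the usual formulation; the helper `compute_gae` and its arguments are hypothetical and not part of the Memorax API.

```python
from typing import List, Tuple


def compute_gae(
    rewards: List[float],
    values: List[float],
    dones: List[bool],
    last_value: float,
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
) -> Tuple[List[float], List[float]]:
    """Hypothetical helper: standard GAE for one environment's rollout.

    dones[t] is True if the episode ended after step t, so the value
    estimate is not bootstrapped across that boundary.
    """
    num_steps = len(rewards)
    advantages = [0.0] * num_steps
    gae = 0.0
    # Walk the rollout backwards, accumulating discounted TD errors.
    for t in reversed(range(num_steps)):
        next_value = last_value if t == num_steps - 1 else values[t + 1]
        not_done = 0.0 if dones[t] else 1.0
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        gae = delta + gamma * gae_lambda * not_done * gae
        advantages[t] = gae
    # Critic targets are the advantages plus the value baseline.
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns
```

Setting `gae_lambda=1.0` recovers full Monte Carlo returns minus the baseline, while `gae_lambda=0.0` reduces to one-step TD errors; 0.95 trades a little bias for lower variance.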
### Key Parameters
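- `num_envs`: number of parallel environments used for training.
- `num_steps`: steps collected per environment in each rollout before an update.
- `clip_coef`: PPO clipping coefficient; values between 0.1 and 0.3 are typical.
- `burn_in_length`: number of steps used to "warm up" the RNN hidden state before losses are computed.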
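To make these parameters concrete, the sketch below shows how `num_envs`, `num_steps`, `num_minibatches`, and `update_epochs` combine into batch sizes, and the per-sample clipped surrogate loss that `clip_coef` controls. It follows the common CleanRL-style PPO convention, which is an assumption about Memorax's internals; `ppo_batch_sizes` and `clipped_policy_loss` are hypothetical helpers.

```python
import math


def ppo_batch_sizes(num_envs: int, num_steps: int, num_minibatches: int, update_epochs: int) -> dict:
    """Hypothetical helper: batch arithmetic for one PPO update (CleanRL-style)."""
    batch_size = num_envs * num_steps                 # transitions per rollout
    minibatch_size = batch_size // num_minibatches    # transitions per gradient step
    gradient_steps = num_minibatches * update_epochs  # optimizer steps per rollout
    return {"batch_size": batch_size, "minibatch_size": minibatch_size, "gradient_steps": gradient_steps}


def clipped_policy_loss(new_logp: float, old_logp: float, advantage: float, clip_coef: float = 0.2) -> float:
    """Per-sample PPO clipped surrogate loss (to be minimized)."""
    ratio = math.exp(new_logp - old_logp)  # pi_new(a|s) / pi_old(a|s)
    clipped_ratio = min(max(ratio, 1.0 - clip_coef), 1.0 + clip_coef)
    # PPO keeps the pessimistic (smaller) objective; negate it to get a loss.
    return -min(ratio * advantage, clipped_ratio * advantage)


print(ppo_batch_sizes(num_envs=8, num_steps=128, num_minibatches=4, update_epochs=4))
# {'batch_size': 1024, 'minibatch_size': 256, 'gradient_steps': 16}
```

With a positive advantage, the clipping stops the loss from rewarding ratios above `1 + clip_coef`, so a larger `clip_coef` permits bigger policy changes per update.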
## DQN (Deep Q-Network)
For discrete action spaces with value-based learning.
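```python
from memorax.algorithms import DQN, DQNConfig

cfg = DQNConfig(
    name="DQN-Experiment",
    num_envs=8,
    buffer_size=100_000,         # replay buffer capacity
    batch_size=32,
    learning_starts=1000,        # steps collected before updates begin
    target_update_freq=1000,     # steps between target-network updates
    gamma=0.99,
    epsilon_start=1.0,           # initial exploration rate
    epsilon_end=0.05,            # final exploration rate
    epsilon_decay_steps=50_000,  # steps over which epsilon is annealed
)

# env, env_params, q_network and optimizer must be created beforehand
agent = DQN(cfg, env, env_params, q_network, optimizer)
```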
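The `epsilon_*` fields describe an exploration schedule. Below is a minimal sketch of epsilon-greedy action selection with a linear decay from `epsilon_start` to `epsilon_end` over `epsilon_decay_steps`, plus the periodic hard target-network update implied by `target_update_freq`. Linear decay and hard target copies are assumptions (common DQN defaults), not confirmed Memorax behavior; `linear_epsilon`, `select_action`, and `should_update_target` are hypothetical names.

```python
import random
from typing import List


def linear_epsilon(step: int, start: float = 1.0, end: float = 0.05, decay_steps: int = 50_000) -> float:
    """Linearly anneal epsilon from `start` to `end`, then hold it at `end`."""
    fraction = min(step / decay_steps, 1.0)
    return start + fraction * (end - start)


def select_action(q_values: List[float], epsilon: float, rng: random.Random) -> int:
    """Epsilon-greedy: random action with probability epsilon, else argmax Q."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])


def should_update_target(step: int, target_update_freq: int = 1000) -> bool:
    """Hard update: copy online weights into the target network every N steps."""
    return step > 0 and step % target_update_freq == 0


rng = random.Random(0)
eps = linear_epsilon(25_000)  # halfway through the decay
action = select_action([0.1, 0.9, 0.3], eps, rng)
```

Halfway through the decay, epsilon is midway between 1.0 and 0.05 (0.525), so roughly half of the actions are still random.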
## SAC (Soft Actor-Critic)
For continuous control with entropy regularization.
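```python
from memorax.algorithms import SAC, SACConfig

cfg = SACConfig(
    name="SAC-Experiment",
    num_envs=8,
    buffer_size=1_000_000,   # replay buffer capacity
    batch_size=256,
    learning_starts=10_000,  # steps collected before updates begin
    gamma=0.99,
    tau=0.005,               # soft (Polyak) target-update rate
    alpha=0.2,               # Temperature parameter
    auto_alpha=True,         # Learn temperature
)

# env, env_params, actor, critic and the three optimizers must be created beforehand
agent = SAC(cfg, env, env_params, actor, critic, critic, actor_optimizer, critic_optimizer, alpha_optimizer)
```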
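`tau` sets how quickly the target critic tracks the online critic. In standard SAC this is a Polyak (exponential moving average) update applied after each gradient step. The sketch below shows that rule on a flat list of parameters, assuming Memorax follows the standard formulation; `polyak_update` is a hypothetical name, and a JAX implementation would typically apply the same expression across parameter pytrees (for example with `jax.tree_util.tree_map`).

```python
from typing import List


def polyak_update(online: List[float], target: List[float], tau: float = 0.005) -> List[float]:
    """Soft target update: target <- tau * online + (1 - tau) * target."""
    return [tau * o + (1.0 - tau) * t for o, t in zip(online, target)]


# Toy example: a target parameter starting at 0.0 tracking an online parameter fixed at 1.0.
target = [0.0]
for _ in range(1000):
    target = polyak_update([1.0], target, tau=0.005)

# After n updates the gap shrinks by a factor of (1 - tau) ** n,
# so target[0] is about 1 - 0.995 ** 1000, roughly 0.993.
print(round(target[0], 3))
```

A small `tau` such as 0.005 makes the target change slowly, which stabilizes the bootstrapped critic targets at the cost of propagating improvements more gradually.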
## Training Loop Pattern
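All algorithms share the same functional interface: every method takes a random key and returns a new key alongside its results, so rebind `key` (and `state`) after each call.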
```python
import jax

# `agent` is constructed as in the sections above; assumes JAX PRNG keys.
key = jax.random.PRNGKey(0)

# Initialize
key, state = agent.init(key)

# Optional: warm up (fills the replay buffer for off-policy algorithms)
key, state = agent.warmup(key, state, num_steps=10_000)

# Train
key, state, transitions = agent.train(key, state, num_steps=100_000)

# Evaluate
key, returns = agent.evaluate(key, state, num_episodes=10)
```
## Burn-in for Recurrent Networks
When using RNNs or SSMs, enable burn-in so the hidden state carries context from earlier steps:
```python
from memorax.algorithms import PPOConfig

cfg = PPOConfig(
    burn_in_length=20,  # 20 steps of context
    num_steps=128,
    # ...remaining fields as in the PPO example
)
```
The first `burn_in_length` steps of each sequence are replayed without gradients to initialize the hidden state before the losses are computed.
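The sketch below illustrates burn-in with a toy scalar recurrence standing in for an RNN: the first `burn_in_length` steps only advance the hidden state, and the loss is accumulated over the remaining steps. Whether Memorax excludes the burn-in steps from the loss in exactly this way is an assumption, and `rnn_step` and `burn_in_loss` are hypothetical names. In a JAX implementation, the warmed-up state would also be passed through `jax.lax.stop_gradient` so that no gradient flows through the burn-in prefix.

```python
from typing import List, Tuple


def rnn_step(hidden: float, obs: float, decay: float = 0.9) -> float:
    """Toy recurrent cell: exponentially decaying sum of observations."""
    return decay * hidden + obs


def burn_in_loss(observations: List[float], targets: List[float], burn_in_length: int) -> Tuple[float, float]:
    """Hypothetical sketch: warm up the hidden state, then compute a loss on the rest."""
    hidden = 0.0  # initial hidden state at the start of the stored sequence
    # 1) Burn-in: only advance the hidden state; these steps add nothing to the loss.
    #    (In JAX, the resulting state would be wrapped in jax.lax.stop_gradient.)
    for obs in observations[:burn_in_length]:
        hidden = rnn_step(hidden, obs)
    warmed_hidden = hidden
    # 2) Training segment: keep unrolling from the warmed-up state and
    #    accumulate squared error between the hidden state and the targets.
    loss = 0.0
    for obs, target in zip(observations[burn_in_length:], targets[burn_in_length:]):
        hidden = rnn_step(hidden, obs)
        loss += (hidden - target) ** 2
    num_loss_steps = len(observations) - burn_in_length
    return warmed_hidden, loss / max(num_loss_steps, 1)


warmed, loss = burn_in_loss([1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 2.71, 3.439], burn_in_length=2)
```

Without burn-in, the training segment would start from a zero hidden state that does not reflect the preceding observations; the burn-in prefix supplies that context without spending gradient computation on it.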