Working with Algorithms

Memorax provides several reinforcement learning (RL) algorithms optimized for memory-augmented learning.

Available Algorithms

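| Algorithm | Action Space          | Use Case                         |
|-----------|-----------------------|----------------------------------|
| PPO       | Discrete & Continuous | General-purpose, stable training |
| DQN       | Discrete              | Value-based learning             |
| SAC       | Continuous            | Maximum entropy RL               |
| PQN       | Discrete              | On-policy Q-learning             |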

PPO (Proximal Policy Optimization)

Best for general-purpose training with memory architectures.

from memorax.algorithms import PPO, PPOConfig

cfg = PPOConfig(
    name="PPO-Experiment",
    num_envs=8,
    num_eval_envs=16,
    num_steps=128,
    gamma=0.99,
    gae_lambda=0.95,
    num_minibatches=4,
    update_epochs=4,
    normalize_advantage=True,
    clip_coef=0.2,
    clip_vloss=True,
    ent_coef=0.01,
    vf_coef=0.5,
    burn_in_length=0,  # RNN burn-in steps
)

agent = PPO(cfg, env, env_params, actor, critic, actor_optimizer, critic_optimizer)
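The gamma and gae_lambda settings above control how rewards are turned into advantage estimates. The following is a minimal pure-Python sketch of generalized advantage estimation (GAE) for a single environment. It illustrates the standard technique only; the function name compute_gae and its signature are hypothetical, and Memorax's internal implementation may differ.

```python
def compute_gae(rewards, values, last_value, dones, gamma=0.99, gae_lambda=0.95):
    """Return (advantages, returns) for one rollout of a single environment.

    Hypothetical helper for illustration; not part of the Memorax API.
    dones[t] is True when the episode ended after step t.
    """
    advantages = [0.0] * len(rewards)
    next_value = last_value      # bootstrap value for the step after the rollout
    next_advantage = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 1.0 - float(dones[t])
        # One-step TD error; the bootstrap term is dropped at episode ends.
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        # Exponentially weighted sum of TD errors, reset at episode ends.
        next_advantage = delta + gamma * gae_lambda * not_done * next_advantage
        advantages[t] = next_advantage
        next_value = values[t]
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns
```

With gae_lambda=1.0 this reduces to discounted Monte Carlo returns minus the value baseline; smaller values lean more on the critic, trading variance for bias.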

Key Parameters

  • num_envs: Number of parallel environments for training

  • num_steps: Steps per rollout before update

  • clip_coef: PPO clipping coefficient (typically 0.1-0.3)

  • burn_in_length: Steps to “warm up” RNN hidden state
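To make num_envs, num_steps, and clip_coef concrete, here is a small sketch of the rollout size they imply and of the standard PPO clipped surrogate for a single sample. Both function names are hypothetical illustrations rather than Memorax functions, and the library's loss may include details not shown here.

```python
def transitions_per_rollout(num_envs, num_steps):
    """Transitions collected before each update (hypothetical helper)."""
    return num_envs * num_steps


def clipped_surrogate(ratio, advantage, clip_coef=0.2):
    """Standard PPO clipped objective for one sample (to be maximized).

    ratio is pi_new(a|s) / pi_old(a|s). Illustration only, not Memorax code.
    """
    clipped_ratio = min(max(ratio, 1.0 - clip_coef), 1.0 + clip_coef)
    # Taking the minimum removes any incentive to move the ratio
    # outside [1 - clip_coef, 1 + clip_coef].
    return min(ratio * advantage, clipped_ratio * advantage)


# With the config above: 8 environments x 128 steps per rollout.
rollout_size = transitions_per_rollout(num_envs=8, num_steps=128)
```

A larger clip_coef permits bigger policy changes per update, which can speed up learning but makes training less stable.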

DQN (Deep Q-Network)

For discrete action spaces with value-based learning.

from memorax.algorithms import DQN, DQNConfig

cfg = DQNConfig(
    name="DQN-Experiment",
    num_envs=8,
    buffer_size=100_000,
    batch_size=32,
    learning_starts=1000,
    target_update_freq=1000,
    gamma=0.99,
    epsilon_start=1.0,
    epsilon_end=0.05,
    epsilon_decay_steps=50_000,
)

agent = DQN(cfg, env, env_params, q_network, optimizer)
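The three epsilon_* settings describe an exploration schedule. The sketch below assumes a linear decay from epsilon_start to epsilon_end over epsilon_decay_steps, which is a common choice; the source does not state which schedule Memorax uses, and linear_epsilon is a hypothetical helper.

```python
def linear_epsilon(step, epsilon_start=1.0, epsilon_end=0.05, epsilon_decay_steps=50_000):
    """Exploration rate at a given step, assuming a linear schedule.

    Hypothetical helper; the schedule Memorax actually uses may differ.
    """
    # Fraction of the decay completed, clamped to [0, 1].
    fraction = min(max(step / epsilon_decay_steps, 0.0), 1.0)
    return epsilon_start + fraction * (epsilon_end - epsilon_start)


# Exploration rate at the start, midway, and well after the decay window.
schedule = [linear_epsilon(s) for s in (0, 25_000, 200_000)]
```

Under this assumption the agent acts fully at random at step 0 and holds a constant exploration rate of epsilon_end once the decay window has passed.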

SAC (Soft Actor-Critic)

For continuous control with entropy regularization.

from memorax.algorithms import SAC, SACConfig

cfg = SACConfig(
    name="SAC-Experiment",
    num_envs=8,
    buffer_size=1_000_000,
    batch_size=256,
    learning_starts=10_000,
    gamma=0.99,
    tau=0.005,
    alpha=0.2,  # Temperature parameter
    auto_alpha=True,  # Learn temperature
)

agent = SAC(cfg, env, env_params, actor, critic, critic, actor_optimizer, critic_optimizer, alpha_optimizer)
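In SAC, tau conventionally sets the rate of the soft (Polyak) update that moves the target critic toward the online critic. The following sketch shows that update on a plain dictionary of scalar parameters. It assumes Memorax follows this convention; soft_update and the parameter layout are hypothetical.

```python
def soft_update(target_params, online_params, tau=0.005):
    """Polyak-average online parameters into the target parameters.

    Hypothetical helper operating on a flat dict of floats, for illustration.
    """
    return {
        name: (1.0 - tau) * target_params[name] + tau * online_params[name]
        for name in target_params
    }


# Toy parameters: the target moves 0.5% of the way toward the online values.
target = {"w": 0.0, "b": 1.0}
online = {"w": 1.0, "b": 1.0}
new_target = soft_update(target, online, tau=0.005)
```

A small tau keeps the target network slow-moving, which stabilizes the bootstrapped critic targets at the cost of slower propagation of new value estimates.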

Training Loop Pattern

All algorithms follow the same interface:

# Initialize
key, state = agent.init(key)

# Optional: warm up (fill the replay buffer for off-policy algorithms)
key, state = agent.warmup(key, state, num_steps=10_000)

# Train
key, state, transitions = agent.train(key, state, num_steps=100_000)

# Evaluate
key, returns = agent.evaluate(key, state, num_episodes=10)

Burn-in for Recurrent Networks

When using RNNs/SSMs, use burn-in to establish hidden state context:

cfg = PPOConfig(
    burn_in_length=20,  # 20 steps of context
    num_steps=128,
)

Burn-in replays the first burn_in_length steps without gradients to initialize the hidden state before the losses are computed.
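One way to picture burn-in is as a per-step mask over each rollout sequence: the leading steps only update the hidden state, and the remaining steps contribute to the loss. The sketch below assumes the burn-in steps are taken from the start of the num_steps rollout and excluded from the loss. The source does not confirm this detail, and burn_in_loss_mask is a hypothetical helper.

```python
def burn_in_loss_mask(num_steps, burn_in_length):
    """Return a per-step mask: False for burn-in steps, True for loss steps.

    Hypothetical helper; assumes burn-in steps come from the start of the
    rollout and are excluded from the loss.
    """
    if not 0 <= burn_in_length < num_steps:
        raise ValueError("burn_in_length must be in [0, num_steps)")
    return [t >= burn_in_length for t in range(num_steps)]


# With the config above, 20 steps warm up the hidden state.
mask = burn_in_loss_mask(num_steps=128, burn_in_length=20)
loss_steps = sum(mask)
```

Under this assumption, a longer burn-in gives the recurrent state more context but leaves fewer steps per rollout for the gradient update.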