# Working with Algorithms

Memorax provides several RL algorithms designed for agents with memory, such as recurrent neural networks (RNNs) and state-space models (SSMs).

## Available Algorithms

| Algorithm | Action Space | Use Case |
|-----------|--------------|----------|
| PPO | Discrete & Continuous | General-purpose, stable training |
| DQN | Discrete | Value-based learning |
| SAC | Continuous | Maximum entropy RL |
| PQN | Discrete | On-policy Q-learning |

## PPO (Proximal Policy Optimization)

A good default for general-purpose training with memory architectures.

```python
from memorax.algorithms import PPO, PPOConfig

cfg = PPOConfig(
    name="PPO-Experiment",
    num_envs=8,
    num_eval_envs=16,
    num_steps=128,
    gamma=0.99,
    gae_lambda=0.95,
    num_minibatches=4,
    update_epochs=4,
    normalize_advantage=True,
    clip_coef=0.2,
    clip_vloss=True,
    ent_coef=0.01,
    vf_coef=0.5,
    burn_in_length=0,  # RNN burn-in steps
)

agent = PPO(cfg, env, env_params, actor, critic, actor_optimizer, critic_optimizer)
```
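To make `gamma`, `gae_lambda` and `normalize_advantage` concrete, here is a minimal NumPy sketch of generalized advantage estimation (GAE) over one rollout of shape `(num_steps, num_envs)`. It follows the common CleanRL-style formulation and is not taken from Memorax's source; the helper names `compute_gae` and `normalize` and the `dones` convention (1.0 when the episode ended at that step) are assumptions for illustration.

```python
import numpy as np

def compute_gae(rewards, values, dones, last_value, gamma=0.99, gae_lambda=0.95):
    """Hypothetical GAE helper (not a Memorax function).

    rewards, values, dones: arrays of shape (num_steps, num_envs).
    last_value: value estimate of the state after the final step, shape (num_envs,).
    """
    num_steps = rewards.shape[0]
    advantages = np.zeros_like(rewards)
    last_gae = np.zeros_like(last_value)
    for t in reversed(range(num_steps)):
        next_value = last_value if t == num_steps - 1 else values[t + 1]
        not_done = 1.0 - dones[t]
        # One-step TD error; bootstrapping is cut at episode boundaries
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        # Exponentially weighted sum of future TD errors
        last_gae = delta + gamma * gae_lambda * not_done * last_gae
        advantages[t] = last_gae
    returns = advantages + values  # regression targets for the critic
    return advantages, returns

def normalize(advantages, eps=1e-8):
    # What normalize_advantage=True typically does before the policy loss
    return (advantages - advantages.mean()) / (advantages.std() + eps)
```

With `gamma=1.0`, `gae_lambda=1.0` and zero value estimates, the advantages reduce to plain reward-to-go, which makes the function easy to sanity-check.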

### Key Parameters

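- `num_envs`: Number of parallel environments used for training
- `num_eval_envs`: Number of parallel environments used for evaluation
- `num_steps`: Steps collected in each environment per rollout, before each update
- `clip_coef`: PPO clipping coefficient; values between 0.1 and 0.3 are typical
- `burn_in_length`: Number of initial steps replayed without gradients to "warm up" the RNN hidden state (0 disables burn-in)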
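The role of `clip_coef` (together with `clip_vloss`, `vf_coef` and `ent_coef` from the config above) can be sketched as the standard clipped PPO loss. This is a NumPy illustration of the usual formulation, assuming per-sample log-probabilities, values and entropies are already computed; `ppo_loss` is a hypothetical helper, and Memorax's internal loss may differ in detail.

```python
import numpy as np

def ppo_loss(new_logp, old_logp, advantages, new_values, old_values, returns,
             entropy, clip_coef=0.2, clip_vloss=True, ent_coef=0.01, vf_coef=0.5):
    # Probability ratio between the current policy and the rollout policy
    ratio = np.exp(new_logp - old_logp)
    # Clipped surrogate objective, negated so that lower is better
    pg_unclipped = -advantages * ratio
    pg_clipped = -advantages * np.clip(ratio, 1.0 - clip_coef, 1.0 + clip_coef)
    pg_loss = np.maximum(pg_unclipped, pg_clipped).mean()

    if clip_vloss:
        # Pessimistic value loss: the larger of the unclipped error and the error
        # of a prediction clipped to within clip_coef of the old value
        v_clipped = old_values + np.clip(new_values - old_values, -clip_coef, clip_coef)
        v_loss = 0.5 * np.maximum((new_values - returns) ** 2,
                                  (v_clipped - returns) ** 2).mean()
    else:
        v_loss = 0.5 * ((new_values - returns) ** 2).mean()

    # Entropy bonus is subtracted to encourage exploration
    total = pg_loss - ent_coef * entropy.mean() + vf_coef * v_loss
    return total, pg_loss, v_loss
```

Taking the elementwise maximum of the two negated surrogates means the objective stops rewarding ratio changes beyond `1 ± clip_coef` in the direction the advantage favours, which keeps each update close to the rollout policy.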

## DQN (Deep Q-Network)

For discrete action spaces with value-based learning.

```python
from memorax.algorithms import DQN, DQNConfig

cfg = DQNConfig(
    name="DQN-Experiment",
    num_envs=8,
    buffer_size=100_000,
    batch_size=32,
    learning_starts=1000,
    target_update_freq=1000,
    gamma=0.99,
    epsilon_start=1.0,
    epsilon_end=0.05,
    epsilon_decay_steps=50_000,
)

agent = DQN(cfg, env, env_params, q_network, optimizer)
```
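The exploration and target-network settings map onto a few small formulas. The sketch below assumes a linear epsilon decay from `epsilon_start` to `epsilon_end` over `epsilon_decay_steps`, a hard target-network copy every `target_update_freq` steps, and no gradient updates before `learning_starts`. These are common DQN choices; Memorax's exact schedules may differ, and all helper names are hypothetical.

```python
import numpy as np

def linear_epsilon(step, epsilon_start=1.0, epsilon_end=0.05, epsilon_decay_steps=50_000):
    # Linear interpolation from epsilon_start to epsilon_end, then held constant
    frac = min(step / epsilon_decay_steps, 1.0)
    return epsilon_start + frac * (epsilon_end - epsilon_start)

def epsilon_greedy(q_values, epsilon, rng):
    # With probability epsilon pick a uniformly random action, else the greedy one
    if rng.random() < epsilon:
        return int(rng.integers(len(q_values)))
    return int(np.argmax(q_values))

def td_target(reward, done, next_q_target, gamma=0.99):
    # One-step target from the target network; no bootstrap past terminal states
    return reward + gamma * (1.0 - done) * np.max(next_q_target, axis=-1)

def can_update(step, learning_starts=1000):
    # Gradient updates begin only after learning_starts environment steps
    return step >= learning_starts

def should_sync_target(step, target_update_freq=1000):
    # Hard copy of the online weights into the target network
    return step > 0 and step % target_update_freq == 0
```

With the config above, epsilon is halfway between 1.0 and 0.05 (0.525) after 25,000 steps and stays at 0.05 from step 50,000 onward.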

## SAC (Soft Actor-Critic)

For continuous control with entropy regularization.

```python
from memorax.algorithms import SAC, SACConfig

cfg = SACConfig(
    name="SAC-Experiment",
    num_envs=8,
    buffer_size=1_000_000,
    batch_size=256,
    learning_starts=10_000,
    gamma=0.99,
    tau=0.005,
    alpha=0.2,  # Temperature parameter
    auto_alpha=True,  # Learn temperature
)

agent = SAC(cfg, env, env_params, actor, critic, critic, actor_optimizer, critic_optimizer, alpha_optimizer)
```
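`tau`, `alpha` and `auto_alpha` enter SAC through three standard pieces: Polyak-averaged target critics, an entropy-regularized (soft) Bellman target, and, when the temperature is learned, a temperature loss. The NumPy sketch below follows the usual SAC formulation with twin critics and clipped double-Q; whether Memorax uses exactly this form, and how it picks `target_entropy`, are assumptions. The helper names are hypothetical, and parameters are modelled as a dict of arrays for simplicity.

```python
import numpy as np

def polyak_update(target_params, online_params, tau=0.005):
    # Soft update: target <- tau * online + (1 - tau) * target, per parameter array
    return {name: tau * online_params[name] + (1.0 - tau) * target_params[name]
            for name in target_params}

def soft_q_target(reward, done, next_q1, next_q2, next_logp, gamma=0.99, alpha=0.2):
    # Clipped double-Q: take the smaller twin target estimate, then subtract
    # alpha * log pi(a'|s'), which adds an entropy bonus to the target
    next_v = np.minimum(next_q1, next_q2) - alpha * next_logp
    return reward + gamma * (1.0 - done) * next_v

def temperature_loss(log_alpha, logp, target_entropy):
    # Minimized w.r.t. log_alpha when the temperature is learned (auto_alpha=True);
    # logp is treated as a constant. target_entropy = -action_dim is a common choice.
    return -(np.exp(log_alpha) * (logp + target_entropy)).mean()
```

A small `tau` such as 0.005 moves the target critics only slightly toward the online critics on each update, which stabilizes the bootstrapped targets.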

## Training Loop Pattern

All algorithms follow the same interface:

```python
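import jax

# `agent` is constructed as in the sections above
key = jax.random.PRNGKey(0)

# Initialize
key, state = agent.init(key)

# Optional: warm up, e.g. to fill the replay buffer for off-policy algorithms
key, state = agent.warmup(key, state, num_steps=10_000)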

# Train
key, state, transitions = agent.train(key, state, num_steps=100_000)

# Evaluate
key, returns = agent.evaluate(key, state, num_episodes=10)
```
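Because each call returns a fresh `key` and `state`, training and evaluation are easy to interleave in fixed-size chunks. The loop below is a sketch built only on the interface shown above; `DummyAgent` is a hypothetical stand-in (integer keys, a step counter as state) so the loop can run without Memorax, and `run_in_chunks` is not a Memorax function.

```python
import numpy as np

class DummyAgent:
    """Hypothetical stand-in with the init/warmup/train/evaluate interface.

    Keys are plain ints and the state is a step counter, purely for illustration."""

    def init(self, key):
        return key + 1, {"step": 0}

    def warmup(self, key, state, num_steps):
        return key + 1, state

    def train(self, key, state, num_steps):
        transitions = None  # a real agent returns the collected transitions
        return key + 1, {"step": state["step"] + num_steps}, transitions

    def evaluate(self, key, state, num_episodes):
        return key + 1, np.full(num_episodes, float(state["step"]))

def run_in_chunks(agent, key, total_steps, eval_every, num_episodes=10, warmup_steps=0):
    key, state = agent.init(key)
    if warmup_steps > 0:
        key, state = agent.warmup(key, state, num_steps=warmup_steps)
    mean_returns = []
    for _ in range(total_steps // eval_every):
        # Always thread the returned key and state into the next call
        key, state, _ = agent.train(key, state, num_steps=eval_every)
        key, returns = agent.evaluate(key, state, num_episodes=num_episodes)
        mean_returns.append(float(np.mean(returns)))
    return key, state, mean_returns
```

With a real agent, replace `DummyAgent()` with the `agent` built earlier and pass the same kind of `key` used in the loop above. In this sketch, any remainder when `total_steps` is not a multiple of `eval_every` is simply skipped.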

## Burn-in for Recurrent Networks

When using RNNs/SSMs, use burn-in to establish hidden state context:

```python
cfg = PPOConfig(
    burn_in_length=20,  # 20 steps of context
    num_steps=128,
)
```

During training, the first `burn_in_length` steps of each sequence are passed through the network without gradients, only to initialize the hidden state; the losses are then computed on the remaining steps.
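A minimal NumPy sketch of the same idea, with a plain tanh RNN cell standing in for the memory module: the first `burn_in_length` inputs only advance the hidden state, and the loss would be computed on the outputs for the remaining steps. NumPy has no autodiff, so the gradient cut is marked with a comment; in JAX it would typically be `jax.lax.stop_gradient`. The cell, `burn_in` and `unroll` are hypothetical helpers, and counting the burn-in steps within `num_steps` is an assumption.

```python
import numpy as np

def rnn_step(h, x, W, U):
    # Minimal tanh RNN cell standing in for the agent's memory module
    return np.tanh(h @ W + x @ U)

def burn_in(xs, h0, W, U, burn_in_length):
    # Advance the hidden state over the first burn_in_length inputs only
    h = h0
    for x in xs[:burn_in_length]:
        h = rnn_step(h, x, W, U)
    # In an autodiff framework, gradients are blocked here,
    # e.g. h = jax.lax.stop_gradient(h) in JAX
    return h, xs[burn_in_length:]

def unroll(xs, h, W, U):
    # Hidden states for the steps that contribute to the loss
    hs = []
    for x in xs:
        h = rnn_step(h, x, W, U)
        hs.append(h)
    return np.stack(hs)
```

With `num_steps=128` and `burn_in_length=20`, 108 steps per sequence contribute to the loss, and their hidden states match those of an uninterrupted unroll from the same initial state.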
