Quick Start
This guide walks through training a PPO agent with GRU memory on CartPole.
Environment
Create a Gymnax environment wrapped with episode statistics tracking:
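```python
from memorax.environments import make
from memorax.environments.wrappers import RecordEpisodeStatistics

# Create CartPole-v1 from the Gymnax backend; make() returns the env and its params.
env, env_params = make("gymnax::CartPole-v1")
# Record per-episode returns and lengths so they can be logged during training.
env = RecordEpisodeStatistics(env)
```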
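The training loop later reads `returned_episode`, `returned_episode_returns` and `returned_episode_lengths` from the `info` logs, and this wrapper is what produces them. The following minimal JAX sketch illustrates the kind of bookkeeping that episode-statistics tracking involves. It is not Memorax's implementation: `EpisodeStats` and `update_stats` are hypothetical names. The sketch assumes that the reported return and length are meaningful only on steps where `returned_episode` is True.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class EpisodeStats(NamedTuple):
    """Hypothetical running counters, one entry per environment."""
    episode_return: jax.Array
    episode_length: jax.Array


def update_stats(stats: EpisodeStats, reward: jax.Array, done: jax.Array):
    """Advance the counters by one step and report finished episodes.

    The reported return/length are the running totals including this step;
    they are only meaningful where `returned_episode` (== done) is True.
    """
    new_return = stats.episode_return + reward
    new_length = stats.episode_length + 1
    info = {
        "returned_episode": done,
        "returned_episode_returns": new_return,
        "returned_episode_lengths": new_length,
    }
    # Reset the counters of environments whose episode just ended.
    stats = EpisodeStats(
        episode_return=jnp.where(done, 0.0, new_return),
        episode_length=jnp.where(done, 0, new_length),
    )
    return stats, info


# Two environments, four steps, reward 1 per step.
num_envs = 2
stats = EpisodeStats(jnp.zeros(num_envs), jnp.zeros(num_envs, dtype=jnp.int32))
rewards = jnp.ones((4, num_envs))
dones = jnp.array([[False, False], [False, True], [False, False], [True, False]])

stats, infos = jax.lax.scan(
    lambda s, xs: update_stats(s, *xs), stats, (rewards, dones)
)
# Keep only steps where an episode ended, as the training loop does.
finished_returns = infos["returned_episode_returns"][infos["returned_episode"]]
print(finished_returns)  # [2. 4.]
```

Environment 1 finishes a 2-step episode and environment 0 finishes a 4-step episode. Boolean indexing returns those totals in row-major (time, then environment) order.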
Configuration
Define the PPO hyperparameters:
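```python
from memorax.algorithms import PPOConfig

config = PPOConfig(
    num_envs=8,                 # parallel environments
    num_steps=128,              # rollout length per environment before each update
    gae_lambda=0.95,            # GAE bias-variance trade-off
    num_minibatches=4,          # minibatches per update epoch
    update_epochs=4,            # optimization passes over each rollout
    normalize_advantage=True,   # standardize advantages (zero mean, unit variance)
    clip_coefficient=0.2,       # PPO probability-ratio clipping range
    clip_value_loss=True,       # also clip the value-function update
    entropy_coefficient=0.01,   # weight of the entropy bonus
)
```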
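The following minimal JAX sketch shows the standard PPO pieces these hyperparameters usually control: Generalized Advantage Estimation (`gae_lambda`) and the clipped actor and critic losses (`clip_coefficient`, `clip_value_loss`, `entropy_coefficient`, `normalize_advantage`). It is not Memorax's code, and it rests on three assumptions. First, the discount `gamma=0.99` is assumed, since it does not appear in the config above. Second, the value-clipping range reuses `clip_coefficient`, as common PPO implementations do. Third, the two losses are returned separately because the agent takes separate actor and critic optimizers.

```python
import jax
import jax.numpy as jnp


def compute_gae(rewards, values, dones, last_value, gamma=0.99, gae_lambda=0.95):
    """GAE over a time-major [num_steps, num_envs] rollout.

    dones[t] = 1.0 means the episode ended at step t, so no value is
    bootstrapped across that boundary.
    """
    next_values = jnp.concatenate([values[1:], last_value[None]], axis=0)

    def step(gae, xs):
        reward, value, next_value, done = xs
        not_done = 1.0 - done
        delta = reward + gamma * next_value * not_done - value
        gae = delta + gamma * gae_lambda * not_done * gae
        return gae, gae

    # Scan backwards in time; outputs come back in the original time order.
    _, advantages = jax.lax.scan(
        step, jnp.zeros_like(last_value), (rewards, values, next_values, dones), reverse=True
    )
    return advantages, advantages + values  # (advantages, value targets)


def actor_loss(new_log_prob, old_log_prob, advantages, entropy,
               clip_coefficient=0.2, entropy_coefficient=0.01, normalize_advantage=True):
    """Clipped surrogate objective with an entropy bonus (to be minimized)."""
    if normalize_advantage:
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    ratio = jnp.exp(new_log_prob - old_log_prob)
    clipped_ratio = jnp.clip(ratio, 1.0 - clip_coefficient, 1.0 + clip_coefficient)
    surrogate = jnp.minimum(ratio * advantages, clipped_ratio * advantages)
    return -surrogate.mean() - entropy_coefficient * entropy.mean()


def critic_loss(new_value, old_value, returns, clip_coefficient=0.2, clip_value_loss=True):
    """Squared-error value loss, optionally clipped around the old prediction."""
    unclipped = (new_value - returns) ** 2
    if clip_value_loss:
        # Assumption: the value clip range reuses clip_coefficient.
        value_clipped = old_value + jnp.clip(new_value - old_value, -clip_coefficient, clip_coefficient)
        return 0.5 * jnp.maximum(unclipped, (value_clipped - returns) ** 2).mean()
    return 0.5 * unclipped.mean()


# One env, three steps, episode ending at step 1; gamma = lambda = 1 for easy arithmetic.
rewards = jnp.ones((3, 1))
values = jnp.zeros((3, 1))
dones = jnp.array([[0.0], [1.0], [0.0]])
advantages, returns = compute_gae(rewards, values, dones, jnp.zeros(1), gamma=1.0, gae_lambda=1.0)
print(advantages[:, 0])  # [2. 1. 1.]
```

The sketch assumes Memorax follows the usual PPO rollout layout. Under that assumption, the config above collects `num_envs × num_steps = 8 × 128 = 1,024` transitions per seed in each update. Those transitions are split into 4 minibatches and revisited for 4 epochs, giving 16 gradient steps per update.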
Networks
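Memorax networks follow a feature_extractor → torso → head pipeline. The feature extractor maps raw observations to feature vectors. The torso models the sequence over time; here it is a GRU that carries a hidden state between steps. The head produces the final output: an action distribution for the actor or a value estimate for the critic. The actor and critic below are built from the same feature extractor and torso definitions and differ only in their heads.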
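```python
import flax.linen as nn
from memorax.networks import FeatureExtractor, Network, RNN, heads

d_model = 64

# Feature extractor: a Dense + ReLU layer embeds each observation.
feature_extractor = FeatureExtractor(
    observation_extractor=nn.Sequential((nn.Dense(d_model), nn.relu)),
)
# Torso: a GRU that carries memory across timesteps.
torso = RNN(cell=nn.GRUCell(features=d_model))

# Actor: categorical policy over CartPole's discrete actions.
# A small orthogonal init (scale 0.01) keeps the initial policy near-uniform.
actor_network = Network(
    feature_extractor=feature_extractor,
    torso=torso,
    head=heads.Categorical(
        action_dim=env.action_space(env_params).n,
        kernel_init=nn.initializers.orthogonal(scale=0.01),
    ),
)
# Critic: same feature extractor and torso definitions, scalar value head.
critic_network = Network(
    feature_extractor=feature_extractor,
    torso=torso,
    head=heads.VNetwork(kernel_init=nn.initializers.orthogonal(scale=1.0)),
)
```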
Agent
Combine the config, environment, networks, and optimizer into a PPO agent:
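```python
import optax
from memorax.algorithms import PPO

# Clip gradients to a global norm of 1.0, then apply Adam.
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate=3e-4, eps=1e-5),
)

agent = PPO(
    config=config,
    env=env,
    env_params=env_params,
    actor_network=actor_network,
    critic_network=critic_network,
    # An optax transformation holds no state itself (state comes from its
    # init function), so one definition can serve both actor and critic.
    actor_optimizer=optimizer,
    critic_optimizer=optimizer,
)
```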
Training
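Use jax.vmap to vectorize across seeds and lox.spool to capture training metrics as a dictionary of logs. Training runs in chunks of 10,000 timesteps. After each chunk, the statistics of completed episodes are logged, and the logger updates a live dashboard in the terminal: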
```python
import jax
import lox
from memorax.loggers import DashboardLogger, MultiLogger

total_timesteps = 500_000
steps_per_iteration = 10_000
num_seeds = 1

logger = MultiLogger([
    DashboardLogger(
        total_timesteps=total_timesteps,
        summary={"Algorithm": "PPO-GRU", "Environment": "CartPole"},
    )
])

# Vectorize over seeds. In train, keys and states are batched along axis 0,
# while the number of timesteps (in_axes=None) is shared by all seeds.
init = jax.vmap(agent.init)
train = jax.vmap(lox.spool(agent.train), in_axes=(0, 0, None))

key = jax.random.key(0)
key, init_key = jax.random.split(key)
state = init(jax.random.split(init_key, num_seeds))

for _ in range(0, total_timesteps, steps_per_iteration):
    key, train_key = jax.random.split(key)
    state, logs = train(jax.random.split(train_key, num_seeds), state, steps_per_iteration)

    # Keep only the entries where an episode actually finished.
    info = logs.pop("info")
    episode_returns = info["returned_episode_returns"][info["returned_episode"]]
    episode_lengths = info["returned_episode_lengths"][info["returned_episode"]]

    data = {
        "training/episode_returns": episode_returns,
        "training/episode_lengths": episode_lengths,
        **logs,
    }
    logger.log(data, step=state.step.mean().item())

logger.finish()
```
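The vmap-and-mask pattern in the loop can be tried without Memorax. In this sketch, `fake_init` and `fake_train` are hypothetical stand-ins for `agent.init` and `agent.train`, `lox.spool` is left out, and `NUM_FAKE_STEPS` is an arbitrary rollout length. The example shows two things. First, `in_axes=(0, 0, None)` maps keys and states per seed while sharing the timestep count. Second, boolean indexing keeps only finished episodes, flattened across seeds and time.

```python
import jax
import jax.numpy as jnp

NUM_FAKE_STEPS = 6  # hypothetical rollout length for the stand-in train function


def fake_init(key):
    """Hypothetical stand-in for agent.init: one scalar step counter per seed."""
    del key  # a real init would use the key to initialize parameters
    return jnp.zeros((), dtype=jnp.int32)


def fake_train(key, step, num_steps):
    """Hypothetical stand-in for agent.train: advances the counter and emits episode info."""
    returned_episode = jax.random.bernoulli(key, 0.5, (NUM_FAKE_STEPS,))
    returns = jnp.arange(NUM_FAKE_STEPS, dtype=jnp.float32)
    info = {"returned_episode": returned_episode, "returned_episode_returns": returns}
    return step + num_steps, info


num_seeds = 3
init = jax.vmap(fake_init)
# Map over per-seed keys and states (axis 0); share num_steps across seeds (None).
train = jax.vmap(fake_train, in_axes=(0, 0, None))

key = jax.random.PRNGKey(0)
key, init_key = jax.random.split(key)
state = init(jax.random.split(init_key, num_seeds))
key, train_key = jax.random.split(key)
state, info = train(jax.random.split(train_key, num_seeds), state, 10_000)

# Boolean indexing (outside jit) keeps only entries where an episode finished.
episode_returns = info["returned_episode_returns"][info["returned_episode"]]
print(state, episode_returns)
```

Every output gains a leading seed axis of size `num_seeds`. This is why the real loop reduces the step counter with `state.step.mean()` before logging.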
Next Steps
- Learn about the different algorithms in Working with Algorithms
- Explore the available sequence models in Sequence Models
- Build custom networks in Building Networks