Quick Start

Contents

Quick Start

This guide shows how to train a PPO agent with a GRU recurrent memory on the gymnax CartPole-v1 environment.

from dataclasses import asdict

import flax.linen as nn
import jax
import optax

from memorax.algorithms import PPO, PPOConfig
from memorax.environments import environment
from memorax.loggers import DashboardLogger, Logger
from memorax.networks import MLP, FeatureExtractor, Network, RNN, heads

# Create the environment and its parameters from the gymnax registry id.
env, env_params = environment.make("gymnax::CartPole-v1")

# Configure PPO. Keyword arguments are grouped by role; order does not matter.
cfg = PPOConfig(
    name="PPO-GRU",
    # Number of parallel environments for training and for evaluation
    num_envs=8,
    num_eval_envs=16,
    # Rollout length per environment and return/advantage estimation
    num_steps=128,
    gamma=0.99,
    gae_lambda=0.95,
    normalize_advantage=True,
    # Update schedule over each collected rollout
    update_epochs=4,
    num_minibatches=4,
    # Loss terms: clipping and coefficients
    clip_coef=0.2,
    clip_vloss=True,
    vf_coef=0.5,
    ent_coef=0.01,
)

# Build networks. Actor and critic are assembled from the same feature-extractor
# and GRU-torso definitions and differ only in their output head.
# NOTE(review): presumably PPO initializes actor and critic separately, so
# these shared module definitions do not share parameters. Verify in PPO.
d_model = 64

feature_extractor = FeatureExtractor(
    observation_extractor=MLP(
        features=(d_model,),
        kernel_init=nn.initializers.orthogonal(scale=1.414),
    ),
)
torso = RNN(cell=nn.GRUCell(features=d_model))

# Policy head: one categorical output per discrete action. It uses a
# small-scale orthogonal init, a common PPO choice for the policy layer.
num_actions = env.action_space(env_params).n
actor_head = heads.Categorical(
    action_dim=num_actions,
    kernel_init=nn.initializers.orthogonal(scale=0.01),
)
actor_network = Network(
    feature_extractor=feature_extractor,
    torso=torso,
    head=actor_head,
)

# Value head: unit-scale orthogonal init.
critic_head = heads.VNetwork(kernel_init=nn.initializers.orthogonal(scale=1.0))
critic_network = Network(
    feature_extractor=feature_extractor,
    torso=torso,
    head=critic_head,
)

# Create optimizer: clip gradients to a global norm of 1.0, then apply Adam.
optimizer = optax.chain(
    optax.clip_by_global_norm(max_norm=1.0),
    optax.adam(3e-4, eps=1e-5),
)

# Training budget, in the units of agent.train's `num_steps`. Each number is
# defined once so the dashboard's progress total and the loop cannot drift apart.
total_timesteps = 500_000
steps_per_iteration = 10_000

# Initialize agent and logger.
# The same optax transformation is passed for actor and critic. Optax
# transformations are stateless init/update pairs, so this presumably yields
# independent optimizer states as long as PPO calls init per network. Verify.
agent = PPO(
    cfg=cfg,
    env=env,
    env_params=env_params,
    actor=actor_network,
    critic=critic_network,
    actor_optimizer=optimizer,
    critic_optimizer=optimizer,
)
logger = Logger(
    [DashboardLogger(title="PPO-GRU CartPole", total_timesteps=total_timesteps)]
)
logger_state = logger.init(cfg=asdict(cfg))

# Train
key = jax.random.key(0)
key, state = agent.init(key)

# Train in chunks of `steps_per_iteration`, logging after each chunk.
for _ in range(0, total_timesteps, steps_per_iteration):
    key, state, transitions = agent.train(key, state, num_steps=steps_per_iteration)
    training_statistics = Logger.get_episode_statistics(transitions, "training")
    # Log against the agent's own step counter rather than the loop index.
    logger_state = logger.log(logger_state, training_statistics, step=state.step.item())
    logger.emit(logger_state)

logger.finish(logger_state)

Next Steps