Source code for memorax.algorithms.ppo

from functools import partial
from typing import Any, Callable, Optional

import jax
import jax.numpy as jnp
import optax
from flax import core
from flax import linen as nn
from flax import struct

from memorax.networks.sequence_models.utils import (
    add_feature_axis,
    remove_feature_axis,
    remove_time_axis,
)
from memorax.networks.sequence_models.wrappers import SequenceModelWrapper
from memorax.utils import Timestep, Transition, generalized_advantage_estimatation
from memorax.utils.typing import Array, Discrete, Environment, EnvParams, EnvState, Key


[docs] @struct.dataclass(frozen=True) class PPOConfig: name: str num_envs: int num_eval_envs: int num_steps: int gamma: float gae_lambda: float num_minibatches: int update_epochs: int normalize_advantage: bool clip_coef: float clip_vloss: bool ent_coef: float vf_coef: float target_kl: Optional[float] = None burn_in_length: int = 0 @property def batch_size(self): return self.num_envs * self.num_steps
[docs] @struct.dataclass(frozen=True) class PPOState: step: int timestep: Timestep env_state: EnvState actor_params: core.FrozenDict[str, Any] actor_optimizer_state: optax.OptState actor_carry: Array critic_params: core.FrozenDict[str, Any] critic_optimizer_state: optax.OptState critic_carry: Array
[docs] @struct.dataclass(frozen=True) class PPO: cfg: PPOConfig env: Environment env_params: EnvParams actor: nn.Module critic: nn.Module actor_optimizer: optax.GradientTransformation critic_optimizer: optax.GradientTransformation def _deterministic_action( self, key: Key, state: PPOState ) -> tuple[Key, PPOState, Array, Array]: timestep = state.timestep.to_sequence() actor_carry, probs = self.actor.apply( state.actor_params, observation=timestep.obs, mask=timestep.done, action=timestep.action, reward=add_feature_axis(timestep.reward), done=timestep.done, initial_carry=state.actor_carry, ) action = ( jnp.argmax(probs.logits, axis=-1) if isinstance(self.env.action_space(self.env_params), Discrete) else probs.mode() ) log_prob = probs.log_prob(action) action = remove_time_axis(action) log_prob = remove_time_axis(log_prob) state = state.replace( actor_carry=actor_carry, ) return key, state, action, log_prob, None def _stochastic_action( self, key: Key, state: PPOState ) -> tuple[Key, PPOState, Array, Array]: ( key, action_key, actor_memory_key, critic_memory_key, ) = jax.random.split(key, 4) timestep = state.timestep.to_sequence() actor_carry, probs = self.actor.apply( state.actor_params, observation=timestep.obs, mask=timestep.done, action=timestep.action, reward=add_feature_axis(timestep.reward), done=timestep.done, initial_carry=state.actor_carry, rngs={"memory": actor_memory_key}, ) action, log_prob = probs.sample_and_log_prob(seed=action_key) critic_carry, value = self.critic.apply( state.critic_params, observation=timestep.obs, mask=timestep.done, action=timestep.action, reward=add_feature_axis(timestep.reward), done=timestep.done, initial_carry=state.critic_carry, rngs={"memory": critic_memory_key}, ) action = remove_time_axis(action) log_prob = remove_time_axis(log_prob) value = remove_time_axis(value) value = remove_feature_axis(value) state = state.replace( actor_carry=actor_carry, critic_carry=critic_carry, ) return key, state, action, log_prob, value def _step(self, carry: tuple, _, *, policy: Callable): key, state = carry key, action_key, step_key = jax.random.split(key, 3) key, state, action, log_prob, value = policy(action_key, state) num_envs, *_ = state.timestep.obs.shape step_key = jax.random.split(step_key, num_envs) next_obs, env_state, reward, done, info = jax.vmap( self.env.step, in_axes=(0, 0, 0, None) )(step_key, state.env_state, action, self.env_params) broadcast_dims = tuple( range(state.timestep.done.ndim, state.timestep.action.ndim) ) prev_action = jnp.where( jnp.expand_dims(state.timestep.done, axis=broadcast_dims), jnp.zeros_like(state.timestep.action), state.timestep.action, ) transition = Transition( obs=state.timestep.obs, # type: ignore action=action, # type: ignore reward=reward, # type: ignore done=done, # type: ignore info=info, # type: ignore log_prob=log_prob, # type: ignore value=value, # type: ignore prev_action=prev_action, # type: ignore prev_reward=jnp.where(state.timestep.done, 0, state.timestep.reward), # type: ignore prev_done=state.timestep.done, ) state = state.replace( step=state.step + self.cfg.num_envs, timestep=Timestep(obs=next_obs, action=action, reward=reward, done=done), env_state=env_state, ) return (key, state), transition def _update_actor( self, key, state: PPOState, initial_actor_carry, transitions, advantages ): key, memory_key, dropout_key = jax.random.split(key, 3) if self.cfg.burn_in_length > 0: burn_in = jax.tree.map( lambda x: x[:, : self.cfg.burn_in_length], transitions ) initial_actor_carry, _ = self.actor.apply( jax.lax.stop_gradient(state.actor_params), observation=burn_in.obs, mask=burn_in.prev_done, action=burn_in.prev_action, reward=add_feature_axis(burn_in.prev_reward), done=burn_in.prev_done, initial_carry=initial_actor_carry, ) initial_actor_carry = jax.lax.stop_gradient(initial_actor_carry) transitions = jax.tree.map( lambda x: x[:, self.cfg.burn_in_length :], transitions ) advantages = advantages[:, self.cfg.burn_in_length :] def actor_loss_fn(params): _, probs = self.actor.apply( params, observation=transitions.obs, mask=transitions.prev_done, action=transitions.prev_action, reward=add_feature_axis(transitions.prev_reward), done=transitions.prev_done, initial_carry=initial_actor_carry, rngs={"memory": memory_key, "dropout": dropout_key}, ) log_probs = probs.log_prob(transitions.action) entropy = probs.entropy().mean() ratio = jnp.exp(log_probs - transitions.log_prob) approx_kl = jnp.mean(transitions.log_prob - log_probs) clipfrac = jnp.mean( (jnp.abs(ratio - 1.0) > self.cfg.clip_coef).astype(jnp.float32) ) actor_loss = -jnp.minimum( ratio * advantages, jnp.clip( ratio, 1.0 - self.cfg.clip_coef, 1.0 + self.cfg.clip_coef, ) * advantages, ).mean() return actor_loss - self.cfg.ent_coef * entropy, ( entropy.mean(), # type: ignore approx_kl.mean(), # type: ignore clipfrac.mean(), # type: ignore ) (actor_loss, aux), actor_grads = jax.value_and_grad( actor_loss_fn, has_aux=True )(state.actor_params) actor_updates, actor_optimizer_state = self.actor_optimizer.update( actor_grads, state.actor_optimizer_state, state.actor_params ) actor_params = optax.apply_updates(state.actor_params, actor_updates) state = state.replace( actor_params=actor_params, actor_optimizer_state=actor_optimizer_state, ) return key, state, actor_loss.mean(), aux def _update_critic( self, key, state: PPOState, initial_critic_carry, transitions, returns ): key, memory_key, dropout_key = jax.random.split(key, 3) if self.cfg.burn_in_length > 0: burn_in = jax.tree.map( lambda x: x[:, : self.cfg.burn_in_length], transitions ) initial_critic_carry, _ = self.critic.apply( jax.lax.stop_gradient(state.critic_params), observation=burn_in.obs, mask=burn_in.prev_done, action=burn_in.prev_action, reward=add_feature_axis(burn_in.prev_reward), done=burn_in.prev_done, initial_carry=initial_critic_carry, ) initial_critic_carry = jax.lax.stop_gradient(initial_critic_carry) transitions = jax.tree.map( lambda x: x[:, self.cfg.burn_in_length :], transitions ) returns = returns[:, self.cfg.burn_in_length :] def critic_loss_fn(params): _, values = self.critic.apply( params, observation=transitions.obs, mask=transitions.prev_done, action=transitions.prev_action, reward=add_feature_axis(transitions.prev_reward), done=transitions.prev_done, initial_carry=initial_critic_carry, rngs={"memory": memory_key, "dropout": dropout_key}, ) values = remove_feature_axis(values) if self.cfg.clip_vloss: critic_loss = jnp.square(values - returns) clipped_value = transitions.value + jnp.clip( (values - transitions.value), -self.cfg.clip_coef, self.cfg.clip_coef, ) clipped_critic_loss = jnp.square(clipped_value - returns) critic_loss = 0.5 * jnp.maximum(critic_loss, clipped_critic_loss).mean() else: critic_loss = 0.5 * jnp.square(values - returns).mean() return critic_loss critic_loss, critic_grads = jax.value_and_grad(critic_loss_fn)( state.critic_params ) critic_updates, critic_optimizer_state = self.critic_optimizer.update( critic_grads, state.critic_optimizer_state, state.critic_params ) critic_params = optax.apply_updates(state.critic_params, critic_updates) state = state.replace( critic_params=critic_params, critic_optimizer_state=critic_optimizer_state ) return key, state, critic_loss.mean() def _update_minibatch(self, carry, minibatch: tuple): key, state = carry ( initial_actor_carry, initial_critic_carry, transitions, advantages, returns, ) = minibatch key, state, critic_loss = self._update_critic( key, state, initial_critic_carry, transitions, returns ) key, state, actor_loss, aux = self._update_actor( key, state, initial_actor_carry, transitions, advantages ) return (key, state), (actor_loss, critic_loss, aux) def _update_epoch(self, carry: tuple): ( key, state, initial_actor_h_epoch, initial_critic_h_epoch, transitions, advantages, returns, *_, epoch, ) = carry key, permutation_key = jax.random.split(key) batch = ( initial_actor_h_epoch, initial_critic_h_epoch, transitions, advantages, returns, ) def shuffle(batch): shuffle_time_axis = isinstance( self.actor.torso, SequenceModelWrapper ) and isinstance(self.critic.torso, SequenceModelWrapper) num_permutations = self.cfg.num_envs if shuffle_time_axis: batch = ( initial_actor_h_epoch, initial_critic_h_epoch, *jax.tree.map( lambda x: x.reshape(-1, 1, *x.shape[2:]), (transitions, advantages, returns), ), ) num_permutations *= self.cfg.num_steps permutation = jax.random.permutation(permutation_key, num_permutations) minibatches = jax.tree.map( lambda x: jnp.take(x, permutation, axis=0).reshape( self.cfg.num_minibatches, -1, *x.shape[1:] ), tuple(batch), ) return minibatches minibatches = shuffle(batch) (key, state), (actor_loss, critic_loss, (entropy, approx_kl, clipfrac)) = ( jax.lax.scan( self._update_minibatch, (key, state), minibatches, ) ) metrics = jax.tree.map( lambda x: x.mean(), (actor_loss, critic_loss, entropy, approx_kl, clipfrac) ) return ( key, state, initial_actor_h_epoch, initial_critic_h_epoch, transitions, advantages, returns, metrics, epoch + 1, ) def _update_step(self, carry: tuple, _): key, state = carry initial_actor_h_rollout = state.actor_carry initial_critic_h_rollout = state.critic_carry (key, state), transitions = jax.lax.scan( partial(self._step, policy=self._stochastic_action), (key, state), length=self.cfg.num_steps, ) timestep = state.timestep.to_sequence() _, value = self.critic.apply( state.critic_params, observation=timestep.obs, mask=timestep.done, action=timestep.action, reward=add_feature_axis(timestep.reward), done=timestep.done, initial_carry=state.critic_carry, ) value = remove_time_axis(value) value = remove_feature_axis(value) advantages, returns = generalized_advantage_estimatation( self.cfg.gamma, self.cfg.gae_lambda, value, transitions, ) transitions = jax.tree.map(lambda x: jnp.swapaxes(x, 0, 1), transitions) advantages = jnp.swapaxes(advantages, 0, 1) returns = jnp.swapaxes(returns, 0, 1) if self.cfg.normalize_advantage: advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) def cond_fun(carry): *_, (*_, approx_kl, _), epoch = carry cond = epoch < self.cfg.update_epochs if self.cfg.target_kl: cond = cond & (approx_kl < self.cfg.target_kl) return cond (key, state, *_, metrics, _) = jax.lax.while_loop( cond_fun, self._update_epoch, ( key, state, initial_actor_h_rollout, initial_critic_h_rollout, transitions, advantages, returns, (0.0, 0.0, 0.0, 0.0, 0.0), 0, ), ) actor_loss, critic_loss, entropy, approx_kl, clipfrac = jax.tree.map( lambda x: jnp.expand_dims(x, axis=(0, 1)), metrics ) info = { **transitions.info, "losses/actor_loss": actor_loss, "losses/critic_loss": critic_loss, "losses/entropy": entropy, "losses/approx_kl": approx_kl, "losses/clipfrac": clipfrac, } return ( key, state, ), transitions.replace(obs=None, next_obs=None, info=info)
[docs] @partial(jax.jit, static_argnames=["self"]) def init(self, key): ( key, env_key, actor_key, actor_memory_key, actor_dropout_key, critic_key, critic_memory_key, critic_dropout_key, ) = jax.random.split(key, 8) env_keys = jax.random.split(env_key, self.cfg.num_envs) obs, env_state = jax.vmap(self.env.reset, in_axes=(0, None))( env_keys, self.env_params ) action = jnp.zeros( (self.cfg.num_envs, *self.env.action_space(self.env_params).shape), dtype=self.env.action_space(self.env_params).dtype, ) reward = jnp.zeros((self.cfg.num_envs,), dtype=jnp.float32) done = jnp.ones((self.cfg.num_envs,), dtype=jnp.bool) timestep = Timestep( obs=obs, action=action, reward=reward, done=done ).to_sequence() actor_carry = self.actor.initialize_carry((self.cfg.num_envs, None)) critic_carry = self.critic.initialize_carry((self.cfg.num_envs, None)) actor_params = self.actor.init( { "params": actor_key, "memory": actor_memory_key, "dropout": actor_dropout_key, }, observation=timestep.obs, mask=timestep.done, action=timestep.action, reward=add_feature_axis(timestep.reward), done=timestep.done, initial_carry=actor_carry, ) critic_params = self.critic.init( { "params": critic_key, "memory": critic_memory_key, "dropout": critic_dropout_key, }, observation=timestep.obs, mask=timestep.done, action=timestep.action, reward=add_feature_axis(timestep.reward), done=timestep.done, initial_carry=critic_carry, ) actor_optimizer_state = self.actor_optimizer.init(actor_params) critic_optimizer_state = self.critic_optimizer.init(critic_params) return ( key, PPOState( step=0, # type: ignore timestep=timestep.from_sequence(), actor_carry=actor_carry, # type: ignore critic_carry=critic_carry, # type: ignore env_state=env_state, # type: ignore actor_params=actor_params, # type: ignore critic_params=critic_params, # type: ignore actor_optimizer_state=actor_optimizer_state, # type: ignore critic_optimizer_state=critic_optimizer_state, # type: ignore ), )
[docs] @partial(jax.jit, static_argnames=["self", "num_steps"]) def warmup(self, key, state, num_steps): """No warmup needed for PPO""" return key, state
[docs] @partial(jax.jit, static_argnums=(0, 3)) def train(self, key, state, num_steps): (key, state), transitions = jax.lax.scan( self._update_step, (key, state), length=num_steps // (self.cfg.num_envs * self.cfg.num_steps), ) transitions = jax.tree.map( lambda x: x.swapaxes(1, 2).reshape((-1,) + x.shape[2:]), transitions ) return key, state, transitions
[docs] @partial(jax.jit, static_argnames=["self", "num_steps", "deterministic"]) def evaluate(self, key, state, num_steps, deterministic=True): key, reset_key = jax.random.split(key) reset_key = jax.random.split(reset_key, self.cfg.num_eval_envs) obs, env_state = jax.vmap(self.env.reset, in_axes=(0, None))( reset_key, self.env_params ) action = jnp.zeros( (self.cfg.num_eval_envs, *self.env.action_space(self.env_params).shape), dtype=self.env.action_space(self.env_params).dtype, ) reward = jnp.zeros(self.cfg.num_eval_envs, dtype=jnp.float32) done = jnp.ones(self.cfg.num_eval_envs, dtype=jnp.bool) timestep = Timestep(obs=obs, action=action, reward=reward, done=done) initial_actor_carry = self.actor.initialize_carry( (self.cfg.num_eval_envs, None) ) initial_critic_carry = self.critic.initialize_carry( (self.cfg.num_eval_envs, None) ) state = state.replace( timestep=timestep, actor_carry=initial_actor_carry, critic_carry=initial_critic_carry, env_state=env_state, ) policy = ( self._deterministic_action if deterministic else self._stochastic_action ) (key, *_), transitions = jax.lax.scan( partial(self._step, policy=policy), (key, state), length=num_steps, ) return key, transitions.replace(obs=None, next_obs=None)