Source code for memorax.algorithms.mappo

from dataclasses import dataclass
from functools import partial
from typing import Any, Callable

import flax.linen as nn
import jax
import jax.numpy as jnp
import lox
import optax
from flax import core, struct

from memorax.utils import Timestep, Transition
from memorax.utils.axes import add_time_axis, remove_feature_axis, remove_time_axis
from memorax.utils.typing import (
    Array,
    Carry,
    Environment,
    EnvParams,
    EnvState,
    Key,
    PyTree,
)

to_sequence = lambda timestep: jax.tree.map(
    lambda x: jax.vmap(add_time_axis)(x), timestep
)

from_sequence = lambda timestep: jax.tree.map(
    lambda x: jax.vmap(remove_time_axis)(x), timestep
)


[docs] @struct.dataclass(frozen=True) class MAPPOConfig: num_envs: int num_steps: int gamma: float gae_lambda: float num_minibatches: int update_epochs: int normalize_advantage: bool clip_coefficient: float clip_value_loss: bool entropy_coefficient: float burn_in_length: int = 0 @property def batch_size(self): return self.num_envs * self.num_steps
[docs] @struct.dataclass(frozen=True) class MAPPOState: step: int update_step: int timestep: Timestep env_state: EnvState actor_params: core.FrozenDict[str, Any] actor_optimizer_state: optax.OptState actor_carry: Array critic_params: core.FrozenDict[str, Any] critic_optimizer_state: optax.OptState critic_carry: Array
[docs] @dataclass class MAPPO: cfg: MAPPOConfig env: Environment env_params: EnvParams actor_network: nn.Module critic_network: nn.Module actor_optimizer: optax.GradientTransformation critic_optimizer: optax.GradientTransformation def __post_init__(self): assert ( self.cfg.update_epochs >= 1 ), f"update_epochs ({self.cfg.update_epochs}) must be >= 1" assert ( self.cfg.batch_size % self.cfg.num_minibatches == 0 ), f"num_envs * num_steps ({self.cfg.batch_size}) must be divisible by num_minibatches ({self.cfg.num_minibatches})" def _deterministic_action( self, key: Key, state: MAPPOState ) -> tuple[MAPPOState, Array, Array, None, dict]: ts = to_sequence(state.timestep) (actor_carry, (probs, _)), intermediates = self.actor_network.apply( state.actor_params, *ts, ts.done, state.actor_carry, mutable=["intermediates"], ) action = jnp.argmax(probs.logits, axis=-1) log_prob = probs.log_prob(action) action = jax.vmap(remove_time_axis)(action) log_prob = jax.vmap(remove_time_axis)(log_prob) state = state.replace(actor_carry=actor_carry) return state, action, log_prob, None, intermediates def _stochastic_action( self, key: Key, state: MAPPOState ) -> tuple[MAPPOState, Array, Array, Array, dict]: action_key, actor_torso_key, critic_torso_key = jax.random.split(key, 3) ts = to_sequence(state.timestep) (actor_carry, (probs, _)), intermediates = self.actor_network.apply( state.actor_params, *ts, ts.done, state.actor_carry, rngs={"torso": actor_torso_key}, mutable=["intermediates"], ) action_keys = jax.random.split(action_key, self.env.num_agents) sampled_action, log_prob = jax.vmap(lambda p, k: p.sample_and_log_prob(seed=k))( probs, action_keys ) critic_carry, (value, _) = self.critic_network.apply( state.critic_params, *ts, ts.done, state.critic_carry, rngs={"torso": critic_torso_key}, ) action = jax.vmap(remove_time_axis)(sampled_action) log_prob = jax.vmap(remove_time_axis)(log_prob) value = jax.vmap(remove_time_axis)(value) value = remove_feature_axis(value) state = state.replace(actor_carry=actor_carry, critic_carry=critic_carry) return state, action, log_prob, value, intermediates def _generalized_advantage_estimation(self, carry: tuple, transition: Transition): advantage, next_value = carry delta = ( transition.second.reward + self.cfg.gamma * (1 - transition.second.done) * next_value - transition.aux["value"] ) advantage = ( delta + self.cfg.gamma * self.cfg.gae_lambda * (1 - transition.second.done) * advantage ) return (advantage, transition.aux["value"]), advantage def _step( self, state: MAPPOState, key: Key, *, policy: Callable ) -> tuple[MAPPOState, Transition]: action_key, step_key = jax.random.split(key) state, action, log_prob, value, intermediates = policy(action_key, state) _, num_envs, *_ = state.timestep.obs.shape step_keys = jax.random.split(step_key, num_envs) next_obs, env_state, reward, done, info = jax.vmap( self.env.step, in_axes=(0, 0, 1), out_axes=(1, 0, 1, 1, 0) )(step_keys, state.env_state, action) intermediates = jax.tree.map( lambda x: jnp.mean(jnp.stack(x)), intermediates.get("intermediates", {}), is_leaf=lambda x: isinstance(x, tuple), ) broadcast_dims = tuple( range(state.timestep.done.ndim, state.timestep.action.ndim) ) first = Timestep( obs=state.timestep.obs, action=state.timestep.action, reward=state.timestep.reward, done=state.timestep.done, ) second = Timestep( obs=next_obs, action=action, reward=reward, done=done, ) lox.log({"info": info, "intermediates": intermediates}) transition = Transition( first=first, second=second, aux={"log_prob": log_prob, "value": value}, ) state = state.replace( step=state.step + num_envs, timestep=Timestep( obs=next_obs, action=jnp.where( jnp.expand_dims(done, axis=broadcast_dims), jnp.zeros_like(action), action, ), reward=jnp.where( done, 0, jnp.asarray(reward, dtype=jnp.float32) ), done=done, ), env_state=env_state, ) return state, transition def _update_actor( self, key: Key, state: MAPPOState, initial_actor_carry: Carry, transitions: Transition, ) -> tuple[MAPPOState, Array, tuple[Array, Array, Array]]: torso_key, dropout_key = jax.random.split(key) if self.cfg.burn_in_length > 0: burn_in = jax.tree.map( lambda x: x[:, :, : self.cfg.burn_in_length], transitions ) initial_actor_carry, (_, _) = self.actor_network.apply( jax.lax.stop_gradient(state.actor_params), *burn_in.first, burn_in.first.done, initial_actor_carry, ) initial_actor_carry = jax.lax.stop_gradient(initial_actor_carry) transitions = jax.tree.map( lambda x: x[:, :, self.cfg.burn_in_length :], transitions ) advantages = transitions.aux["advantages"] if self.cfg.normalize_advantage: advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) def actor_loss_fn(params: PyTree): _, (probs, _) = self.actor_network.apply( params, *transitions.first, transitions.first.done, initial_actor_carry, rngs={"torso": torso_key, "dropout": dropout_key}, ) log_probs = probs.log_prob(transitions.second.action) entropy = probs.entropy().mean() ratio = jnp.exp(log_probs - transitions.aux["log_prob"]) approximate_kl = jnp.mean(transitions.aux["log_prob"] - log_probs) clip_fraction = jnp.mean( (jnp.abs(ratio - 1.0) > self.cfg.clip_coefficient).astype(jnp.float32) ) actor_loss = -jnp.minimum( ratio * advantages, jnp.clip( ratio, 1.0 - self.cfg.clip_coefficient, 1.0 + self.cfg.clip_coefficient, ) * advantages, ).mean() return actor_loss - self.cfg.entropy_coefficient * entropy, ( entropy.mean(), approximate_kl.mean(), clip_fraction.mean(), ) (actor_loss, aux), actor_grads = jax.value_and_grad( actor_loss_fn, has_aux=True )(state.actor_params) lox.log({"actor/gradient_norm": optax.global_norm(actor_grads)}) actor_updates, actor_optimizer_state = self.actor_optimizer.update( actor_grads, state.actor_optimizer_state, state.actor_params ) actor_params = optax.apply_updates(state.actor_params, actor_updates) state = state.replace( actor_params=actor_params, actor_optimizer_state=actor_optimizer_state, ) return state, actor_loss.mean(), aux def _update_critic( self, key: Key, state: MAPPOState, initial_critic_carry: Carry, transitions: Transition, ) -> tuple[MAPPOState, Array]: torso_key, dropout_key = jax.random.split(key) if self.cfg.burn_in_length > 0: burn_in = jax.tree.map( lambda x: x[:, :, : self.cfg.burn_in_length], transitions ) initial_critic_carry, (_, _) = self.critic_network.apply( jax.lax.stop_gradient(state.critic_params), *burn_in.first, burn_in.first.done, initial_critic_carry, ) initial_critic_carry = jax.lax.stop_gradient(initial_critic_carry) transitions = jax.tree.map( lambda x: x[:, :, self.cfg.burn_in_length :], transitions ) returns = transitions.aux["returns"] def critic_loss_fn(params: PyTree): _, (values, aux) = self.critic_network.apply( params, *transitions.first, transitions.first.done, initial_critic_carry, rngs={"torso": torso_key, "dropout": dropout_key}, ) values = remove_feature_axis(values) critic_loss = self.critic_network.head.loss( values, aux, returns, transitions=transitions ) if self.cfg.clip_value_loss: clipped_value = transitions.aux["value"] + jnp.clip( (values - transitions.aux["value"]), -self.cfg.clip_coefficient, self.cfg.clip_coefficient, ) clipped_critic_loss = self.critic_network.head.loss( clipped_value, aux, returns, transitions=transitions ) critic_loss = jnp.maximum(critic_loss, clipped_critic_loss) critic_loss = critic_loss.mean() return critic_loss, values (critic_loss, values), critic_grads = jax.value_and_grad( critic_loss_fn, has_aux=True )(state.critic_params) explained_variance = 1 - jnp.var(returns - values) / jnp.var(returns) lox.log( { "critic/gradient_norm": optax.global_norm(critic_grads), "critic/explained_variance": explained_variance, "critic/value": values.mean(), } ) critic_updates, critic_optimizer_state = self.critic_optimizer.update( critic_grads, state.critic_optimizer_state, state.critic_params ) critic_params = optax.apply_updates(state.critic_params, critic_updates) state = state.replace( critic_params=critic_params, critic_optimizer_state=critic_optimizer_state, ) return state, critic_loss.mean() def _update_minibatch( self, state: MAPPOState, xs: tuple ) -> tuple[MAPPOState, tuple[Array, Array, tuple[Array, Array, Array]]]: minibatch, key = xs ( initial_actor_carry, initial_critic_carry, transitions, ) = minibatch actor_key, critic_key = jax.random.split(key) state, critic_loss = self._update_critic( critic_key, state, initial_critic_carry, transitions ) state, actor_loss, aux = self._update_actor( actor_key, state, initial_actor_carry, transitions ) return state, (actor_loss, critic_loss, aux) def _update_epoch(self, carry: tuple, key: Key) -> tuple: ( state, initial_actor_carry, initial_critic_carry, transitions, ) = carry permutation_key, minibatch_key = jax.random.split(key) batch = ( initial_actor_carry, initial_critic_carry, transitions, ) def shuffle(batch: PyTree): shuffle_time_axis = ( initial_actor_carry is None or initial_critic_carry is None ) num_agents = self.env.num_agents num_envs = self.cfg.num_envs num_steps = self.cfg.num_steps if shuffle_time_axis: batch = ( initial_actor_carry, initial_critic_carry, jax.tree.map( lambda x: x.reshape(num_agents, -1, 1, *x.shape[3:]), transitions, ), ) num_samples_per_agent = num_envs * num_steps else: num_samples_per_agent = num_envs permutation = jax.random.permutation(permutation_key, num_samples_per_agent) minibatches = jax.tree.map( lambda x: ( jnp.moveaxis( jnp.take(x, permutation, axis=1).reshape( num_agents, self.cfg.num_minibatches, -1, *x.shape[2:] ), 1, 0, ) if x is not None else None ), tuple(batch), ) return minibatches minibatches = shuffle(batch) minibatch_keys = jax.random.split(minibatch_key, self.cfg.num_minibatches) state, ( actor_loss, critic_loss, (entropy, approximate_kl, clip_fraction), ) = jax.lax.scan( self._update_minibatch, state, (minibatches, minibatch_keys), ) metrics = jax.tree.map( lambda x: x.mean(), (actor_loss, critic_loss, entropy, approximate_kl, clip_fraction), ) return ( state, initial_actor_carry, initial_critic_carry, transitions, ), metrics def _update_step(self, state: MAPPOState, key: Key) -> tuple[MAPPOState, None]: step_key, epoch_key = jax.random.split(key) initial_actor_carry = state.actor_carry initial_critic_carry = state.critic_carry step_keys = jax.random.split(step_key, self.cfg.num_steps) state, transitions = jax.lax.scan( partial(self._step, policy=self._stochastic_action), state, step_keys, ) ts = to_sequence(state.timestep) _, (value, _) = self.critic_network.apply( state.critic_params, *ts, ts.done, state.critic_carry, ) value = jax.vmap(remove_time_axis)(value) value = remove_feature_axis(value) _, advantages = jax.lax.scan( self._generalized_advantage_estimation, (jnp.zeros_like(value), value), transitions, reverse=True, unroll=16, ) returns = advantages + transitions.aux["value"] transitions = transitions.replace( aux={**transitions.aux, "advantages": advantages, "returns": returns} ) transitions = jax.tree.map( lambda x: jnp.moveaxis(x, 0, min(2, x.ndim - 1)), transitions ) epoch_keys = jax.random.split(epoch_key, self.cfg.update_epochs) (state, *_, transitions), metrics = jax.lax.scan( self._update_epoch, ( state, initial_actor_carry, initial_critic_carry, transitions, ), epoch_keys, ) actor_loss, critic_loss, entropy, approximate_kl, clip_fraction = jax.tree.map( lambda x: x.mean(), metrics ) lox.log( { "actor/loss": actor_loss, "critic/loss": critic_loss, "actor/entropy": entropy, "actor/approximate_kl": approximate_kl, "actor/clip_fraction": clip_fraction, } ) return state.replace(update_step=state.update_step + 1), None
[docs] def init(self, key: Key) -> MAPPOState: ( env_key, actor_key, actor_torso_key, actor_dropout_key, critic_key, critic_torso_key, critic_dropout_key, ) = jax.random.split(key, 7) agent_ids = self.env.agents num_agents = self.env.num_agents env_keys = jax.random.split(env_key, self.cfg.num_envs) obs, env_state = jax.vmap(self.env.reset, out_axes=(1, 0))(env_keys) action_space = self.env.action_spaces[agent_ids[0]] action = jnp.zeros( (num_agents, self.cfg.num_envs, *action_space.shape), dtype=action_space.dtype, ) reward = jnp.zeros((num_agents, self.cfg.num_envs), dtype=jnp.float32) done = jnp.ones((num_agents, self.cfg.num_envs), dtype=jnp.bool_) timestep = to_sequence( Timestep(obs=obs, action=action, reward=reward, done=done) ) actor_carry = self.actor_network.initialize_carry( (num_agents, self.cfg.num_envs, None) ) actor_params = self.actor_network.init( { "params": actor_key, "torso": actor_torso_key, "dropout": actor_dropout_key, }, *timestep, timestep.done, actor_carry, ) critic_carry = self.critic_network.initialize_carry( (num_agents, self.cfg.num_envs, None) ) critic_params = self.critic_network.init( { "params": critic_key, "torso": critic_torso_key, "dropout": critic_dropout_key, }, *timestep, timestep.done, critic_carry, ) return MAPPOState( step=0, update_step=0, timestep=from_sequence(timestep), env_state=env_state, actor_params=actor_params, critic_params=critic_params, actor_optimizer_state=self.actor_optimizer.init(actor_params), critic_optimizer_state=self.critic_optimizer.init(critic_params), actor_carry=actor_carry, critic_carry=critic_carry, )
[docs] def warmup(self, key: Key, state: MAPPOState, num_steps: int) -> MAPPOState: return state
[docs] def train(self, key: Key, state: MAPPOState, num_steps: int) -> MAPPOState: keys = jax.random.split( key, num_steps // (self.cfg.num_steps * self.cfg.num_envs) ) state, _ = jax.lax.scan( self._update_step, state, keys, ) return state
[docs] def evaluate(self, key: Key, state: MAPPOState, num_steps: int) -> MAPPOState: reset_key, eval_key = jax.random.split(key) num_agents = self.env.num_agents reset_keys = jax.random.split(reset_key, self.cfg.num_envs) obs, env_state = jax.vmap(self.env.reset, out_axes=(1, 0))(reset_keys) action_space = self.env.action_spaces[self.env.agents[0]] action = jnp.zeros( (num_agents, self.cfg.num_envs, *action_space.shape), dtype=action_space.dtype, ) reward = jnp.zeros((num_agents, self.cfg.num_envs), dtype=jnp.float32) done = jnp.ones((num_agents, self.cfg.num_envs), dtype=jnp.bool_) actor_carry = self.actor_network.initialize_carry( (num_agents, self.cfg.num_envs, None) ) critic_carry = self.critic_network.initialize_carry( (num_agents, self.cfg.num_envs, None) ) state = state.replace( timestep=Timestep(obs=obs, action=action, reward=reward, done=done), env_state=env_state, actor_carry=actor_carry, critic_carry=critic_carry, ) step_keys = jax.random.split(eval_key, num_steps) state, _ = jax.lax.scan( partial(self._step, policy=self._deterministic_action), state, step_keys, ) return state