from dataclasses import dataclass
from functools import partial
from typing import Any, Callable
import flax.linen as nn
import jax
import jax.numpy as jnp
import lox
import optax
from flax import core, struct
from memorax.utils import Timestep, Transition
from memorax.utils.axes import add_time_axis, remove_feature_axis, remove_time_axis
from memorax.utils.typing import (
Array,
Carry,
Environment,
EnvParams,
EnvState,
Key,
PyTree,
)
def to_sequence(timestep):
    """Give every leaf of ``timestep`` a time axis via ``add_time_axis``.

    ``add_time_axis`` is vmapped over each leaf's leading axis, so it is applied
    per slice of that axis (per agent for the agent-first arrays used by
    ``MAPPO``). Presumably it inserts a length-1 time axis — see
    ``memorax.utils.axes`` to confirm.
    """
    return jax.tree.map(jax.vmap(add_time_axis), timestep)


def from_sequence(timestep):
    """Inverse of :func:`to_sequence`: vmapped ``remove_time_axis`` on every leaf."""
    return jax.tree.map(jax.vmap(remove_time_axis), timestep)
@struct.dataclass(frozen=True)
class MAPPOConfig:
    """Hyperparameters for :class:`MAPPO`.

    ``MAPPO`` reads these as plain Python values (to size ``jax.random.split``
    calls and array reshapes), so they must be concrete, not traced.
    """

    num_envs: int  # parallel environments stepped together
    num_steps: int  # rollout length per environment for each update step
    gamma: float  # discount factor used in GAE
    gae_lambda: float  # GAE lambda
    num_minibatches: int  # minibatches per update epoch
    update_epochs: int  # passes over each rollout
    normalize_advantage: bool  # standardize advantages within each minibatch
    clip_coefficient: float  # PPO ratio clip; also the value-clip range
    clip_value_loss: bool  # take max(unclipped, clipped) value loss
    entropy_coefficient: float  # weight of the entropy bonus in the actor loss
    # Leading steps of each replayed sequence used only to warm up the network
    # carry (no gradient); those steps are excluded from the losses.
    burn_in_length: int = 0

    @property
    def batch_size(self):
        """Transitions collected per update step: ``num_envs * num_steps``."""
        return self.num_envs * self.num_steps
@struct.dataclass(frozen=True)
class MAPPOState:
    """Immutable training state threaded through ``MAPPO``'s ``jax.lax.scan`` loops."""

    step: int  # environment steps taken (each rollout step adds num_envs)
    update_step: int  # completed update steps
    timestep: Timestep  # latest obs/action/reward/done, laid out agent-first
    env_state: EnvState  # batched environment state (one entry per env)
    actor_params: core.FrozenDict[str, Any]
    actor_optimizer_state: optax.OptState
    actor_carry: Array  # actor recurrent carry after the latest step
    critic_params: core.FrozenDict[str, Any]
    critic_optimizer_state: optax.OptState
    critic_carry: Array  # critic recurrent carry after the latest step
@dataclass
class MAPPO:
    """Multi-agent PPO with separate (optionally recurrent) actor and critic.

    Rollouts, GAE and the minibatch updates are all driven by
    ``jax.lax.scan``. Per-agent arrays are laid out agent-first,
    ``(num_agents, num_envs, ...)`` (see :meth:`init`).

    During rollouts the networks see length-1 sequences built with
    ``to_sequence``. During updates they see minibatches of shape
    ``(num_agents, batch, time, ...)``.
    """

    cfg: MAPPOConfig
    env: Environment
    env_params: EnvParams
    actor_network: nn.Module
    critic_network: nn.Module
    actor_optimizer: optax.GradientTransformation
    critic_optimizer: optax.GradientTransformation

    def __post_init__(self):
        assert (
            self.cfg.update_epochs >= 1
        ), f"update_epochs ({self.cfg.update_epochs}) must be >= 1"
        assert (
            self.cfg.batch_size % self.cfg.num_minibatches == 0
        ), f"num_envs * num_steps ({self.cfg.batch_size}) must be divisible by num_minibatches ({self.cfg.num_minibatches})"

    def _deterministic_action(
        self, key: Key, state: MAPPOState
    ) -> tuple[MAPPOState, Array, Array, None, dict]:
        """Pick the argmax-logit action for every agent and environment.

        ``key`` is unused. It is accepted only so this method is
        interchangeable with :meth:`_stochastic_action` as ``_step``'s policy.
        The critic is not evaluated, so the value slot is ``None``.

        Returns:
            ``(state, action, log_prob, None, intermediates)``. ``state`` has
            the updated actor carry; ``intermediates`` is the mutable
            ``"intermediates"`` collection from the actor apply.
        """
        ts = to_sequence(state.timestep)
        # NOTE(review): relies on the actor's distribution exposing ``logits``
        # (e.g. categorical); argmax is only meaningful for such policies.
        (actor_carry, (probs, _)), intermediates = self.actor_network.apply(
            state.actor_params,
            *ts,
            ts.done,
            state.actor_carry,
            mutable=["intermediates"],
        )
        action = jnp.argmax(probs.logits, axis=-1)
        log_prob = probs.log_prob(action)
        action = jax.vmap(remove_time_axis)(action)
        log_prob = jax.vmap(remove_time_axis)(log_prob)
        state = state.replace(actor_carry=actor_carry)
        return state, action, log_prob, None, intermediates

    def _stochastic_action(
        self, key: Key, state: MAPPOState
    ) -> tuple[MAPPOState, Array, Array, Array, dict]:
        """Sample actions from the actor and evaluate the critic.

        Each agent gets its own sampling key. Both the actor and critic carries
        in ``state`` are advanced.

        Returns:
            ``(state, action, log_prob, value, intermediates)``. ``value`` has
            its time and feature axes removed.
        """
        action_key, actor_torso_key, critic_torso_key = jax.random.split(key, 3)
        ts = to_sequence(state.timestep)
        (actor_carry, (probs, _)), intermediates = self.actor_network.apply(
            state.actor_params,
            *ts,
            ts.done,
            state.actor_carry,
            rngs={"torso": actor_torso_key},
            mutable=["intermediates"],
        )
        # vmap over the leading (agent) axis of the distribution's parameters.
        action_keys = jax.random.split(action_key, self.env.num_agents)
        sampled_action, log_prob = jax.vmap(lambda p, k: p.sample_and_log_prob(seed=k))(
            probs, action_keys
        )
        critic_carry, (value, _) = self.critic_network.apply(
            state.critic_params,
            *ts,
            ts.done,
            state.critic_carry,
            rngs={"torso": critic_torso_key},
        )
        action = jax.vmap(remove_time_axis)(sampled_action)
        log_prob = jax.vmap(remove_time_axis)(log_prob)
        value = jax.vmap(remove_time_axis)(value)
        value = remove_feature_axis(value)
        state = state.replace(actor_carry=actor_carry, critic_carry=critic_carry)
        return state, action, log_prob, value, intermediates

    def _generalized_advantage_estimation(self, carry: tuple, transition: Transition):
        """One step of a reverse ``jax.lax.scan`` computing GAE.

        ``carry`` is ``(advantage, next_value)`` from the following step. Both
        the bootstrap term and the recursion are masked by
        ``transition.second.done``:

            delta = reward + gamma * (1 - done) * next_value - value
            A     = delta + gamma * lambda * (1 - done) * A_next

        Returns:
            ``((A, value), A)``. This step's value becomes the next
            iteration's ``next_value``.
        """
        advantage, next_value = carry
        delta = (
            transition.second.reward
            + self.cfg.gamma * (1 - transition.second.done) * next_value
            - transition.aux["value"]
        )
        advantage = (
            delta
            + self.cfg.gamma
            * self.cfg.gae_lambda
            * (1 - transition.second.done)
            * advantage
        )
        return (advantage, transition.aux["value"]), advantage

    def _step(
        self, state: MAPPOState, key: Key, *, policy: Callable
    ) -> tuple[MAPPOState, Transition]:
        """Advance every environment one step with actions chosen by ``policy``.

        ``policy`` is :meth:`_stochastic_action` or :meth:`_deterministic_action`.
        Environment ``info`` and the averaged network intermediates are
        logged via ``lox``.

        Returns:
            ``(state, transition)``. ``transition.first`` is the timestep the
            action was chosen from and ``transition.second`` is the result.
            ``aux`` holds ``log_prob`` and ``value``.
        """
        action_key, step_key = jax.random.split(key)
        state, action, log_prob, value, intermediates = policy(action_key, state)

        # obs is agent-first, so axis 1 counts environments.
        _, num_envs, *_ = state.timestep.obs.shape
        step_keys = jax.random.split(step_key, num_envs)
        # Map env.step over environments: actions are sliced on their env axis
        # (1), and per-agent outputs are restacked on axis 1 (agent-first).
        next_obs, env_state, reward, done, info = jax.vmap(
            self.env.step, in_axes=(0, 0, 1), out_axes=(1, 0, 1, 1, 0)
        )(step_keys, state.env_state, action)

        # Sown intermediates are tuples of values; reduce each to its mean.
        intermediates = jax.tree.map(
            lambda x: jnp.mean(jnp.stack(x)),
            intermediates.get("intermediates", {}),
            is_leaf=lambda x: isinstance(x, tuple),
        )

        # Trailing action dims absent from ``done``, so ``done`` can mask actions.
        broadcast_dims = tuple(
            range(state.timestep.done.ndim, state.timestep.action.ndim)
        )
        first = Timestep(
            obs=state.timestep.obs,
            action=state.timestep.action,
            reward=state.timestep.reward,
            done=state.timestep.done,
        )
        second = Timestep(
            obs=next_obs,
            action=action,
            reward=reward,
            done=done,
        )
        lox.log({"info": info, "intermediates": intermediates})
        transition = Transition(
            first=first,
            second=second,
            aux={"log_prob": log_prob, "value": value},
        )

        # Where an episode ended, the stored action is zeroed and the reward
        # reset to 0 before they are fed back into the networks.
        state = state.replace(
            step=state.step + num_envs,
            timestep=Timestep(
                obs=next_obs,
                action=jnp.where(
                    jnp.expand_dims(done, axis=broadcast_dims),
                    jnp.zeros_like(action),
                    action,
                ),
                reward=jnp.where(
                    done, 0, jnp.asarray(reward, dtype=jnp.float32)
                ),
                done=done,
            ),
            env_state=env_state,
        )
        return state, transition

    def _update_actor(
        self,
        key: Key,
        state: MAPPOState,
        initial_actor_carry: Carry,
        transitions: Transition,
    ) -> tuple[MAPPOState, Array, tuple[Array, Array, Array]]:
        """Take one optimizer step on the PPO clipped-surrogate actor loss.

        With ``burn_in_length > 0``, the first steps of every sequence
        (time axis 2) only advance the carry under ``stop_gradient`` and are
        then dropped from the loss.

        Returns:
            ``(state, actor_loss, (entropy, approximate_kl, clip_fraction))``.
            ``actor_loss`` includes the entropy bonus.
        """
        torso_key, dropout_key = jax.random.split(key)
        if self.cfg.burn_in_length > 0:
            burn_in = jax.tree.map(
                lambda x: x[:, :, : self.cfg.burn_in_length], transitions
            )
            initial_actor_carry, (_, _) = self.actor_network.apply(
                jax.lax.stop_gradient(state.actor_params),
                *burn_in.first,
                burn_in.first.done,
                initial_actor_carry,
            )
            initial_actor_carry = jax.lax.stop_gradient(initial_actor_carry)
            transitions = jax.tree.map(
                lambda x: x[:, :, self.cfg.burn_in_length :], transitions
            )

        advantages = transitions.aux["advantages"]
        if self.cfg.normalize_advantage:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        def actor_loss_fn(params: PyTree):
            _, (probs, _) = self.actor_network.apply(
                params,
                *transitions.first,
                transitions.first.done,
                initial_actor_carry,
                rngs={"torso": torso_key, "dropout": dropout_key},
            )
            log_probs = probs.log_prob(transitions.second.action)
            entropy = probs.entropy().mean()
            ratio = jnp.exp(log_probs - transitions.aux["log_prob"])
            # Simple estimator of KL(old || new) from the sampled actions.
            approximate_kl = jnp.mean(transitions.aux["log_prob"] - log_probs)
            clip_fraction = jnp.mean(
                (jnp.abs(ratio - 1.0) > self.cfg.clip_coefficient).astype(jnp.float32)
            )
            actor_loss = -jnp.minimum(
                ratio * advantages,
                jnp.clip(
                    ratio,
                    1.0 - self.cfg.clip_coefficient,
                    1.0 + self.cfg.clip_coefficient,
                )
                * advantages,
            ).mean()
            return actor_loss - self.cfg.entropy_coefficient * entropy, (
                entropy.mean(),
                approximate_kl.mean(),
                clip_fraction.mean(),
            )

        (actor_loss, aux), actor_grads = jax.value_and_grad(
            actor_loss_fn, has_aux=True
        )(state.actor_params)
        lox.log({"actor/gradient_norm": optax.global_norm(actor_grads)})
        actor_updates, actor_optimizer_state = self.actor_optimizer.update(
            actor_grads, state.actor_optimizer_state, state.actor_params
        )
        actor_params = optax.apply_updates(state.actor_params, actor_updates)
        state = state.replace(
            actor_params=actor_params,
            actor_optimizer_state=actor_optimizer_state,
        )
        return state, actor_loss.mean(), aux

    def _update_critic(
        self,
        key: Key,
        state: MAPPOState,
        initial_critic_carry: Carry,
        transitions: Transition,
    ) -> tuple[MAPPOState, Array]:
        """Take one optimizer step on the critic loss toward the GAE returns.

        The per-element loss comes from ``critic_network.head.loss``. With
        ``clip_value_loss``, the value change is clipped to
        ``+/- clip_coefficient`` around the rollout value, and the larger of the
        clipped and unclipped losses is used. Burn-in works as in
        :meth:`_update_actor`.

        Returns:
            ``(state, critic_loss)``.
        """
        torso_key, dropout_key = jax.random.split(key)
        if self.cfg.burn_in_length > 0:
            burn_in = jax.tree.map(
                lambda x: x[:, :, : self.cfg.burn_in_length], transitions
            )
            initial_critic_carry, (_, _) = self.critic_network.apply(
                jax.lax.stop_gradient(state.critic_params),
                *burn_in.first,
                burn_in.first.done,
                initial_critic_carry,
            )
            initial_critic_carry = jax.lax.stop_gradient(initial_critic_carry)
            transitions = jax.tree.map(
                lambda x: x[:, :, self.cfg.burn_in_length :], transitions
            )

        returns = transitions.aux["returns"]

        def critic_loss_fn(params: PyTree):
            _, (values, aux) = self.critic_network.apply(
                params,
                *transitions.first,
                transitions.first.done,
                initial_critic_carry,
                rngs={"torso": torso_key, "dropout": dropout_key},
            )
            values = remove_feature_axis(values)
            critic_loss = self.critic_network.head.loss(
                values, aux, returns, transitions=transitions
            )
            if self.cfg.clip_value_loss:
                clipped_value = transitions.aux["value"] + jnp.clip(
                    (values - transitions.aux["value"]),
                    -self.cfg.clip_coefficient,
                    self.cfg.clip_coefficient,
                )
                clipped_critic_loss = self.critic_network.head.loss(
                    clipped_value, aux, returns, transitions=transitions
                )
                critic_loss = jnp.maximum(critic_loss, clipped_critic_loss)
            critic_loss = critic_loss.mean()
            return critic_loss, values

        (critic_loss, values), critic_grads = jax.value_and_grad(
            critic_loss_fn, has_aux=True
        )(state.critic_params)
        explained_variance = 1 - jnp.var(returns - values) / jnp.var(returns)
        lox.log(
            {
                "critic/gradient_norm": optax.global_norm(critic_grads),
                "critic/explained_variance": explained_variance,
                "critic/value": values.mean(),
            }
        )
        critic_updates, critic_optimizer_state = self.critic_optimizer.update(
            critic_grads, state.critic_optimizer_state, state.critic_params
        )
        critic_params = optax.apply_updates(state.critic_params, critic_updates)
        state = state.replace(
            critic_params=critic_params,
            critic_optimizer_state=critic_optimizer_state,
        )
        return state, critic_loss.mean()

    def _update_minibatch(
        self, state: MAPPOState, xs: tuple
    ) -> tuple[MAPPOState, tuple[Array, Array, tuple[Array, Array, Array]]]:
        """``jax.lax.scan`` body: update the critic, then the actor, on one minibatch.

        ``xs`` is ``((initial_actor_carry, initial_critic_carry, transitions), key)``.
        """
        minibatch, key = xs
        (
            initial_actor_carry,
            initial_critic_carry,
            transitions,
        ) = minibatch
        actor_key, critic_key = jax.random.split(key)
        state, critic_loss = self._update_critic(
            critic_key, state, initial_critic_carry, transitions
        )
        state, actor_loss, aux = self._update_actor(
            actor_key, state, initial_actor_carry, transitions
        )
        return state, (actor_loss, critic_loss, aux)

    def _update_epoch(self, carry: tuple, key: Key) -> tuple:
        """``jax.lax.scan`` body: shuffle the rollout into minibatches and train on each.

        Transition leaves arrive as ``(num_agents, num_envs, num_steps, ...)``.

        If neither network has a carry (both ``None``), individual timesteps
        are shuffled. Otherwise whole environment sequences are permuted, so
        each stays aligned with its initial carry.

        Returns:
            The unchanged ``carry`` (with the updated state) and the per-epoch
            mean ``(actor_loss, critic_loss, entropy, approximate_kl,
            clip_fraction)``.

        Raises:
            ValueError: if the shuffled axis is not divisible by
                ``num_minibatches`` (raised at trace time).
        """
        (
            state,
            initial_actor_carry,
            initial_critic_carry,
            transitions,
        ) = carry
        permutation_key, minibatch_key = jax.random.split(key)

        num_agents = self.env.num_agents
        num_minibatches = self.cfg.num_minibatches

        # Timesteps may only be decoupled when *no* network carries state. If
        # just one network is recurrent, its carry is indexed per environment,
        # so flattening time would pair carries with the wrong samples.
        shuffle_time_axis = initial_actor_carry is None and initial_critic_carry is None
        if shuffle_time_axis:
            shuffled_transitions = jax.tree.map(
                lambda x: x.reshape(num_agents, -1, 1, *x.shape[3:]),
                transitions,
            )
            num_samples_per_agent = self.cfg.num_envs * self.cfg.num_steps
        else:
            shuffled_transitions = transitions
            num_samples_per_agent = self.cfg.num_envs

        if num_samples_per_agent % num_minibatches != 0:
            raise ValueError(
                f"cannot split {num_samples_per_agent} samples per agent into "
                f"num_minibatches ({num_minibatches}) equal minibatches; with "
                f"recurrent networks num_envs must be divisible by num_minibatches"
            )

        batch = (initial_actor_carry, initial_critic_carry, shuffled_transitions)
        permutation = jax.random.permutation(permutation_key, num_samples_per_agent)
        # Permute axis 1 and split it into minibatches, then move the minibatch
        # axis to the front so scan iterates over it. A ``None`` carry is an
        # empty pytree, so tree.map leaves it as ``None``.
        minibatches = jax.tree.map(
            lambda x: jnp.moveaxis(
                jnp.take(x, permutation, axis=1).reshape(
                    num_agents, num_minibatches, -1, *x.shape[2:]
                ),
                1,
                0,
            ),
            batch,
        )

        minibatch_keys = jax.random.split(minibatch_key, num_minibatches)
        state, (
            actor_loss,
            critic_loss,
            (entropy, approximate_kl, clip_fraction),
        ) = jax.lax.scan(
            self._update_minibatch,
            state,
            (minibatches, minibatch_keys),
        )
        metrics = jax.tree.map(
            lambda x: x.mean(),
            (actor_loss, critic_loss, entropy, approximate_kl, clip_fraction),
        )
        return (
            state,
            initial_actor_carry,
            initial_critic_carry,
            transitions,
        ), metrics

    def _update_step(self, state: MAPPOState, key: Key) -> tuple[MAPPOState, None]:
        """Collect one rollout, compute GAE, then run ``update_epochs`` epochs.

        Epoch-averaged losses and diagnostics are logged via ``lox``.
        ``update_step`` is incremented.
        """
        step_key, epoch_key = jax.random.split(key)
        # Carries at rollout start seed the sequence replay during updates.
        initial_actor_carry = state.actor_carry
        initial_critic_carry = state.critic_carry

        step_keys = jax.random.split(step_key, self.cfg.num_steps)
        state, transitions = jax.lax.scan(
            partial(self._step, policy=self._stochastic_action),
            state,
            step_keys,
        )

        # Bootstrap value of the observation that follows the last rollout step.
        ts = to_sequence(state.timestep)
        _, (value, _) = self.critic_network.apply(
            state.critic_params,
            *ts,
            ts.done,
            state.critic_carry,
        )
        value = jax.vmap(remove_time_axis)(value)
        value = remove_feature_axis(value)

        _, advantages = jax.lax.scan(
            self._generalized_advantage_estimation,
            (jnp.zeros_like(value), value),
            transitions,
            reverse=True,
            unroll=16,
        )
        returns = advantages + transitions.aux["value"]
        transitions = transitions.replace(
            aux={**transitions.aux, "advantages": advantages, "returns": returns}
        )
        # scan stacked steps on a new axis 0. Move it to axis 2, so agent-first
        # leaves become (num_agents, num_envs, num_steps, ...). Leaves with
        # fewer dims get the step axis moved to their last position.
        transitions = jax.tree.map(
            lambda x: jnp.moveaxis(x, 0, min(2, x.ndim - 1)), transitions
        )

        epoch_keys = jax.random.split(epoch_key, self.cfg.update_epochs)
        (state, *_, transitions), metrics = jax.lax.scan(
            self._update_epoch,
            (
                state,
                initial_actor_carry,
                initial_critic_carry,
                transitions,
            ),
            epoch_keys,
        )
        actor_loss, critic_loss, entropy, approximate_kl, clip_fraction = jax.tree.map(
            lambda x: x.mean(), metrics
        )
        lox.log(
            {
                "actor/loss": actor_loss,
                "critic/loss": critic_loss,
                "actor/entropy": entropy,
                "actor/approximate_kl": approximate_kl,
                "actor/clip_fraction": clip_fraction,
            }
        )
        return state.replace(update_step=state.update_step + 1), None

    def init(self, key: Key) -> MAPPOState:
        """Reset all environments and initialize network and optimizer state.

        The initial timestep has zero actions and rewards and ``done=True``
        for every agent and environment. The action shape and dtype are taken
        from the first agent's action space.
        """
        (
            env_key,
            actor_key,
            actor_torso_key,
            actor_dropout_key,
            critic_key,
            critic_torso_key,
            critic_dropout_key,
        ) = jax.random.split(key, 7)

        agent_ids = self.env.agents
        num_agents = self.env.num_agents
        env_keys = jax.random.split(env_key, self.cfg.num_envs)
        obs, env_state = jax.vmap(self.env.reset, out_axes=(1, 0))(env_keys)
        action_space = self.env.action_spaces[agent_ids[0]]
        action = jnp.zeros(
            (num_agents, self.cfg.num_envs, *action_space.shape),
            dtype=action_space.dtype,
        )
        reward = jnp.zeros((num_agents, self.cfg.num_envs), dtype=jnp.float32)
        done = jnp.ones((num_agents, self.cfg.num_envs), dtype=jnp.bool_)
        timestep = to_sequence(
            Timestep(obs=obs, action=action, reward=reward, done=done)
        )

        actor_carry = self.actor_network.initialize_carry(
            (num_agents, self.cfg.num_envs, None)
        )
        actor_params = self.actor_network.init(
            {
                "params": actor_key,
                "torso": actor_torso_key,
                "dropout": actor_dropout_key,
            },
            *timestep,
            timestep.done,
            actor_carry,
        )
        critic_carry = self.critic_network.initialize_carry(
            (num_agents, self.cfg.num_envs, None)
        )
        critic_params = self.critic_network.init(
            {
                "params": critic_key,
                "torso": critic_torso_key,
                "dropout": critic_dropout_key,
            },
            *timestep,
            timestep.done,
            critic_carry,
        )
        return MAPPOState(
            step=0,
            update_step=0,
            timestep=from_sequence(timestep),
            env_state=env_state,
            actor_params=actor_params,
            critic_params=critic_params,
            actor_optimizer_state=self.actor_optimizer.init(actor_params),
            critic_optimizer_state=self.critic_optimizer.init(critic_params),
            actor_carry=actor_carry,
            critic_carry=critic_carry,
        )

    def warmup(self, key: Key, state: MAPPOState, num_steps: int) -> MAPPOState:
        """No-op for this on-policy algorithm: returns ``state`` unchanged."""
        return state

    def train(self, key: Key, state: MAPPOState, num_steps: int) -> MAPPOState:
        """Run ``num_steps // cfg.batch_size`` update steps.

        Any remainder of ``num_steps`` smaller than one batch is dropped.
        """
        num_updates = num_steps // self.cfg.batch_size
        keys = jax.random.split(key, num_updates)
        state, _ = jax.lax.scan(
            self._update_step,
            state,
            keys,
        )
        return state

    def evaluate(self, key: Key, state: MAPPOState, num_steps: int) -> MAPPOState:
        """Reset environments and carries, then roll out ``num_steps`` greedy steps.

        The returned state has fresh ``timestep``, ``env_state`` and carries.
        Its ``step`` counter also includes the evaluation steps, so callers
        should not feed it back into training unless that is intended.
        """
        reset_key, eval_key = jax.random.split(key)
        num_agents = self.env.num_agents
        reset_keys = jax.random.split(reset_key, self.cfg.num_envs)
        obs, env_state = jax.vmap(self.env.reset, out_axes=(1, 0))(reset_keys)
        action_space = self.env.action_spaces[self.env.agents[0]]
        action = jnp.zeros(
            (num_agents, self.cfg.num_envs, *action_space.shape),
            dtype=action_space.dtype,
        )
        reward = jnp.zeros((num_agents, self.cfg.num_envs), dtype=jnp.float32)
        done = jnp.ones((num_agents, self.cfg.num_envs), dtype=jnp.bool_)
        actor_carry = self.actor_network.initialize_carry(
            (num_agents, self.cfg.num_envs, None)
        )
        critic_carry = self.critic_network.initialize_carry(
            (num_agents, self.cfg.num_envs, None)
        )
        state = state.replace(
            timestep=Timestep(obs=obs, action=action, reward=reward, done=done),
            env_state=env_state,
            actor_carry=actor_carry,
            critic_carry=critic_carry,
        )
        step_keys = jax.random.split(eval_key, num_steps)
        state, _ = jax.lax.scan(
            partial(self._step, policy=self._deterministic_action),
            state,
            step_keys,
        )
        return state