from dataclasses import dataclass
from functools import partial
from typing import Any, Callable
import flax.linen as nn
import jax
import jax.numpy as jnp
import lox
import optax
from flax import core, struct
from memorax.utils import Timestep, Transition, utils
from memorax.utils.axes import remove_feature_axis, remove_time_axis
from memorax.utils.typing import (
Array,
Carry,
Discrete,
Environment,
EnvParams,
EnvState,
Key,
PyTree,
)
@struct.dataclass
class PPOConfig:
    """Hyperparameters consumed by :class:`PPO`.

    Field order defines the positional constructor signature, so it must not
    be changed.
    """

    # Number of environments stepped in parallel.
    num_envs: int
    # Environment steps collected per environment in each update step.
    num_steps: int
    # Discount factor used by generalized advantage estimation.
    gamma: float
    # GAE lambda (bias/variance trade-off of the advantage estimate).
    gae_lambda: float
    # Number of minibatches each epoch is split into.
    num_minibatches: int
    # Number of passes over the collected rollout per update step.
    update_epochs: int
    # Whether advantages are standardized per minibatch before the actor loss.
    normalize_advantage: bool
    # Clip range for the probability ratio; also bounds the value change
    # when the clipped value loss is enabled.
    clip_coefficient: float
    # Whether the critic uses the clipped value loss.
    clip_value_loss: bool
    # Weight of the entropy bonus subtracted from the actor loss.
    entropy_coefficient: float
    # Leading timesteps of each training sequence used only to warm up the
    # network carry; they are excluded from the losses.
    burn_in_length: int = 0

    @property
    def batch_size(self) -> int:
        """Transitions collected per update step (``num_envs * num_steps``)."""
        total = self.num_envs
        total *= self.num_steps
        return total
@struct.dataclass
class PPOState:
    """Immutable training state threaded through every :class:`PPO` method.

    As a flax struct dataclass it is a JAX pytree and is updated functionally
    with ``.replace(...)``.
    """

    # Total environment steps taken (advanced by ``num_envs`` per env step).
    step: int
    # Number of completed PPO update steps.
    update_step: int
    # Most recent timestep of the vectorized environments.
    timestep: Timestep
    # Batched environment state returned by the environment's reset/step.
    env_state: EnvState
    # Actor network parameters and the actor optimizer's state.
    actor_params: core.FrozenDict[str, Any]
    actor_optimizer_state: optax.OptState
    # Actor recurrent carry; may be None for carry-less networks (assumption
    # based on the None checks in PPO._update_epoch).
    actor_carry: Carry
    # Critic network parameters and the critic optimizer's state.
    critic_params: core.FrozenDict[str, Any]
    critic_optimizer_state: optax.OptState
    # Critic recurrent carry; same caveat as ``actor_carry``.
    critic_carry: Carry
@dataclass
class PPO:
    """Proximal Policy Optimization with separate actor and critic networks.

    Both networks are applied to sequences (``Timestep.to_sequence()``) with an
    explicit carry. A carry may be ``None``, in which case training shuffles
    individual timesteps instead of whole per-environment sequences (see
    ``_update_epoch``). Rollouts and updates are written as ``jax.lax.scan``
    loops.
    """

    cfg: PPOConfig
    env: Environment
    env_params: EnvParams
    actor_network: nn.Module
    critic_network: nn.Module
    actor_optimizer: optax.GradientTransformation
    critic_optimizer: optax.GradientTransformation

    def __post_init__(self):
        """Validate config invariants that do not depend on network carries."""
        assert (
            self.cfg.update_epochs >= 1
        ), f"update_epochs ({self.cfg.update_epochs}) must be >= 1"
        assert (
            self.cfg.batch_size % self.cfg.num_minibatches == 0
        ), f"num_envs * num_steps ({self.cfg.batch_size}) must be divisible by num_minibatches ({self.cfg.num_minibatches})"

    def _deterministic_action(
        self, key: Key, state: PPOState
    ) -> tuple[PPOState, Array, Array, None, dict]:
        """Pick the greedy action (used for evaluation).

        For ``Discrete`` action spaces this is the argmax of the logits,
        otherwise the distribution's mode. The critic is not evaluated, so
        the value slot is ``None`` and ``critic_carry`` is left untouched.
        ``key`` is unused; it only keeps the policy signature expected by
        ``_step``.

        Returns:
            ``(state, action, log_prob, None, intermediates)`` where ``state``
            carries the updated actor carry.
        """
        (actor_carry, (probs, _)), intermediates = self.actor_network.apply(
            state.actor_params,
            *state.timestep.to_sequence(),
            initial_carry=state.actor_carry,
            mutable=["intermediates"],
        )
        action = (
            jnp.argmax(probs.logits, axis=-1)
            if isinstance(self.env.action_space(self.env_params), Discrete)
            else probs.mode()
        )
        log_prob = probs.log_prob(action)
        # Drop the sequence axis introduced by ``to_sequence`` (assumed length 1).
        action = remove_time_axis(action)
        log_prob = remove_time_axis(log_prob)
        state = state.replace(actor_carry=actor_carry)
        return state, action, log_prob, None, intermediates

    def _stochastic_action(
        self, key: Key, state: PPOState
    ) -> tuple[PPOState, Array, Array, Array, dict]:
        """Sample an action from the actor and evaluate the critic (rollouts).

        ``key`` is split into the sampling key and the actor/critic "torso"
        RNG streams.

        Returns:
            ``(state, action, log_prob, value, intermediates)`` where ``state``
            carries both updated carries.
        """
        action_key, actor_torso_key, critic_torso_key = jax.random.split(key, 3)
        ts = state.timestep.to_sequence()
        (actor_carry, (probs, _)), intermediates = self.actor_network.apply(
            state.actor_params,
            *ts,
            initial_carry=state.actor_carry,
            rngs={"torso": actor_torso_key},
            mutable=["intermediates"],
        )
        action, log_prob = probs.sample_and_log_prob(seed=action_key)
        critic_carry, (value, _) = self.critic_network.apply(
            state.critic_params,
            *ts,
            initial_carry=state.critic_carry,
            rngs={"torso": critic_torso_key},
        )
        action = remove_time_axis(action)
        log_prob = remove_time_axis(log_prob)
        value = remove_time_axis(value)
        value = remove_feature_axis(value)
        state = state.replace(
            actor_carry=actor_carry,
            critic_carry=critic_carry,
        )
        return state, action, log_prob, value, intermediates

    def _generalized_advantage_estimation(self, carry: tuple, transition: Transition):
        """Reverse-scan body for generalized advantage estimation (GAE).

        ``carry`` is ``(advantage_{t+1}, value_{t+1})``. Computes

            delta_t = r_t + gamma * (1 - done_t) * V_{t+1} - V_t
            A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}

        where ``done_t`` is the done flag produced by the step from t.

        Returns:
            ``((A_t, V_t), A_t)`` — new carry and the per-step output.
        """
        advantage, next_value = carry
        not_done = 1 - transition.second.done
        delta = (
            transition.second.reward
            + self.cfg.gamma * not_done * next_value
            - transition.aux["value"]
        )
        advantage = (
            delta + self.cfg.gamma * self.cfg.gae_lambda * not_done * advantage
        )
        return (advantage, transition.aux["value"]), advantage

    def _step(
        self, state: PPOState, key: Key, *, policy: Callable
    ) -> tuple[PPOState, Transition]:
        """Advance every environment by one step using ``policy``.

        Args:
            state: current training state.
            key: PRNG key, split into the policy key and per-env step keys.
            policy: ``_stochastic_action`` or ``_deterministic_action``.

        Returns:
            The new state and a ``Transition`` whose ``first`` is the timestep
            the action was chosen from and whose ``second`` holds the chosen
            action, reward and done flag (``obs=None``). ``aux`` stores the
            policy's log-prob and value. In the state's next timestep, action
            and reward are zeroed for environments that reported ``done``.
        """
        action_key, step_key = jax.random.split(key)
        state, action, log_prob, value, intermediates = policy(action_key, state)

        num_envs, *_ = state.timestep.obs.shape
        step_key = jax.random.split(step_key, num_envs)
        next_obs, env_state, reward, done, info = jax.vmap(
            self.env.step, in_axes=(0, 0, 0, None)
        )(step_key, state.env_state, action, self.env_params)

        # Each captured intermediate is a tuple of sown values; log their mean.
        intermediates = jax.tree.map(
            lambda x: jnp.mean(jnp.stack(x)),
            intermediates.get("intermediates", {}),
            is_leaf=lambda x: isinstance(x, tuple),
        )

        # Trailing action dims that ``done`` must be expanded over to mask actions.
        broadcast_dims = tuple(
            range(state.timestep.done.ndim, state.timestep.action.ndim)
        )

        first = Timestep(
            obs=state.timestep.obs,
            action=state.timestep.action,
            reward=state.timestep.reward,
            done=state.timestep.done,
        )
        second = Timestep(
            obs=None,
            action=action,
            reward=reward,
            done=done,
        )

        lox.log({"info": info, "intermediates": intermediates})

        transition = Transition(
            first=first,
            second=second,
            aux={"log_prob": log_prob, "value": value},
        )

        state = state.replace(
            step=state.step + self.cfg.num_envs,
            timestep=Timestep(
                obs=next_obs,
                action=jnp.where(
                    jnp.expand_dims(done, axis=broadcast_dims),
                    jnp.zeros_like(action),
                    action,
                ),
                reward=jnp.where(done, 0, jnp.asarray(reward, dtype=jnp.float32)),
                done=done,
            ),
            env_state=env_state,
        )
        return state, transition

    def _update_actor(
        self,
        key: Key,
        state: PPOState,
        initial_actor_carry: Carry,
        transitions: Transition,
    ) -> tuple[PPOState, Array, tuple[Array, Array, Array]]:
        """Take one gradient step on the actor for a minibatch.

        The first ``cfg.burn_in_length`` steps of each sequence only warm up
        the carry (via ``utils.burn_in``) and are dropped from the loss. The
        loss is the clipped surrogate objective minus
        ``cfg.entropy_coefficient`` times the mean entropy.

        Returns:
            ``(state, total_loss, (entropy, approximate_kl, clip_fraction))``.
        """
        torso_key, dropout_key = jax.random.split(key)

        initial_actor_carry = utils.burn_in(
            self.actor_network,
            state.actor_params,
            transitions.first,
            initial_actor_carry,
            self.cfg.burn_in_length,
        )
        transitions = jax.tree.map(
            lambda x: x[:, self.cfg.burn_in_length :], transitions
        )

        advantages = transitions.aux["advantages"]
        if self.cfg.normalize_advantage:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        def actor_loss_fn(params: PyTree):
            _, (probs, _) = self.actor_network.apply(
                params,
                *transitions.first,
                initial_carry=initial_actor_carry,
                rngs={"torso": torso_key, "dropout": dropout_key},
            )
            log_probs = probs.log_prob(transitions.second.action)
            entropy = probs.entropy().mean()
            ratio = jnp.exp(log_probs - transitions.aux["log_prob"])
            approximate_kl = jnp.mean(transitions.aux["log_prob"] - log_probs)
            clip_fraction = jnp.mean(
                (jnp.abs(ratio - 1.0) > self.cfg.clip_coefficient).astype(jnp.float32)
            )
            actor_loss = -jnp.minimum(
                ratio * advantages,
                jnp.clip(
                    ratio,
                    1.0 - self.cfg.clip_coefficient,
                    1.0 + self.cfg.clip_coefficient,
                )
                * advantages,
            ).mean()
            return actor_loss - self.cfg.entropy_coefficient * entropy, (
                entropy.mean(),
                approximate_kl.mean(),
                clip_fraction.mean(),
            )

        (actor_loss, aux), actor_grads = jax.value_and_grad(
            actor_loss_fn, has_aux=True
        )(state.actor_params)
        lox.log({"actor/gradient_norm": optax.global_norm(actor_grads)})

        actor_updates, actor_optimizer_state = self.actor_optimizer.update(
            actor_grads, state.actor_optimizer_state, state.actor_params
        )
        actor_params = optax.apply_updates(state.actor_params, actor_updates)
        state = state.replace(
            actor_params=actor_params,
            actor_optimizer_state=actor_optimizer_state,
        )
        return state, actor_loss.mean(), aux

    def _update_critic(
        self,
        key: Key,
        state: PPOState,
        initial_critic_carry: Carry,
        transitions: Transition,
    ) -> tuple[PPOState, Array]:
        """Take one gradient step on the critic for a minibatch.

        Burn-in is handled as in ``_update_actor``. The loss comes from
        ``critic_network.head.loss``. With ``cfg.clip_value_loss``, it takes the
        elementwise maximum of that loss and the loss of values clipped to
        within ``cfg.clip_coefficient`` of the rollout values. Logs the
        gradient norm, explained variance and mean predicted value.

        Returns:
            ``(state, critic_loss)``.
        """
        torso_key, dropout_key = jax.random.split(key)

        initial_critic_carry = utils.burn_in(
            self.critic_network,
            state.critic_params,
            transitions.first,
            initial_critic_carry,
            self.cfg.burn_in_length,
        )
        transitions = jax.tree.map(
            lambda x: x[:, self.cfg.burn_in_length :], transitions
        )
        returns = transitions.aux["returns"]

        def critic_loss_fn(params: PyTree):
            _, (values, head_aux) = self.critic_network.apply(
                params,
                *transitions.first,
                initial_carry=initial_critic_carry,
                rngs={"torso": torso_key, "dropout": dropout_key},
            )
            values = remove_feature_axis(values)
            critic_loss = self.critic_network.head.loss(
                values, head_aux, returns, transitions=transitions
            )
            if self.cfg.clip_value_loss:
                clipped_value = transitions.aux["value"] + jnp.clip(
                    (values - transitions.aux["value"]),
                    -self.cfg.clip_coefficient,
                    self.cfg.clip_coefficient,
                )
                clipped_critic_loss = self.critic_network.head.loss(
                    clipped_value, head_aux, returns, transitions=transitions
                )
                critic_loss = jnp.maximum(critic_loss, clipped_critic_loss)
            critic_loss = critic_loss.mean()
            return critic_loss, values

        (critic_loss, values), critic_grads = jax.value_and_grad(
            critic_loss_fn, has_aux=True
        )(state.critic_params)

        # NOTE: evaluates to nan/-inf when the returns have zero variance.
        explained_variance = 1 - jnp.var(returns - values) / jnp.var(returns)
        lox.log(
            {
                "critic/gradient_norm": optax.global_norm(critic_grads),
                "critic/explained_variance": explained_variance,
                "critic/value": values.mean(),
            }
        )

        critic_updates, critic_optimizer_state = self.critic_optimizer.update(
            critic_grads, state.critic_optimizer_state, state.critic_params
        )
        critic_params = optax.apply_updates(state.critic_params, critic_updates)
        state = state.replace(
            critic_params=critic_params, critic_optimizer_state=critic_optimizer_state
        )
        return state, critic_loss.mean()

    def _update_minibatch(
        self, state: PPOState, xs: tuple
    ) -> tuple[PPOState, tuple[Array, Array, tuple[Array, Array, Array]]]:
        """Scan body: update the critic, then the actor, on one minibatch.

        ``xs`` is ``((initial_actor_carry, initial_critic_carry, transitions), key)``.
        """
        minibatch, key = xs
        initial_actor_carry, initial_critic_carry, transitions = minibatch
        actor_key, critic_key = jax.random.split(key)
        state, critic_loss = self._update_critic(
            critic_key, state, initial_critic_carry, transitions
        )
        state, actor_loss, aux = self._update_actor(
            actor_key, state, initial_actor_carry, transitions
        )
        return state, (actor_loss, critic_loss, aux)

    def _update_epoch(self, carry: tuple, key: Key) -> tuple:
        """Scan body: shuffle the rollout into minibatches and update on each.

        ``carry`` is ``(state, initial_actor_carry, initial_critic_carry,
        transitions)`` with transitions laid out ``(num_envs, num_steps, ...)``;
        only ``state`` changes between epochs.

        If BOTH carries are ``None``, individual timesteps are shuffled as
        independent length-1 sequences. Otherwise whole per-environment
        sequences are shuffled together with their initial carries. This keeps
        trajectories contiguous and aligned with the carry they start from.

        Raises:
            ValueError: if the number of shuffled sequences is not divisible
                by ``cfg.num_minibatches`` (checked at trace time).
        """
        (
            state,
            initial_actor_carry,
            initial_critic_carry,
            transitions,
        ) = carry
        permutation_key, minibatch_key = jax.random.split(key)

        # Flatten time only when NEITHER network has a carry. If just one is
        # None, flattening would index the other network's (num_envs, ...)
        # carry with up to num_envs * num_steps indices. jnp.take fills
        # out-of-bounds rows (NaN for floats) instead of raising.
        flatten_time = initial_actor_carry is None and initial_critic_carry is None
        if flatten_time:
            shuffled_transitions = jax.tree.map(
                lambda x: x.reshape(-1, 1, *x.shape[2:]), transitions
            )
            num_samples = self.cfg.num_envs * self.cfg.num_steps
        else:
            shuffled_transitions = transitions
            num_samples = self.cfg.num_envs

        if num_samples % self.cfg.num_minibatches != 0:
            raise ValueError(
                f"number of shuffled sequences ({num_samples}) must be divisible "
                f"by num_minibatches ({self.cfg.num_minibatches})"
            )

        batch = (initial_actor_carry, initial_critic_carry, shuffled_transitions)
        permutation = jax.random.permutation(permutation_key, num_samples)
        # None carries are empty pytrees, so tree.map leaves them as None.
        minibatches = jax.tree.map(
            lambda x: jnp.take(x, permutation, axis=0).reshape(
                self.cfg.num_minibatches, -1, *x.shape[1:]
            ),
            batch,
        )

        minibatch_keys = jax.random.split(minibatch_key, self.cfg.num_minibatches)
        state, (
            actor_loss,
            critic_loss,
            (entropy, approximate_kl, clip_fraction),
        ) = jax.lax.scan(
            self._update_minibatch,
            state,
            (minibatches, minibatch_keys),
        )

        metrics = jax.tree.map(
            lambda x: x.mean(),
            (actor_loss, critic_loss, entropy, approximate_kl, clip_fraction),
        )
        return (
            state,
            initial_actor_carry,
            initial_critic_carry,
            transitions,
        ), metrics

    def _update_step(self, state: PPOState, key: Key) -> tuple[PPOState, None]:
        """Collect one rollout and run ``cfg.update_epochs`` epochs of updates.

        Steps: roll out ``cfg.num_steps`` stochastic steps per environment,
        bootstrap with the critic's value at the final timestep, compute GAE
        advantages and returns, then update. Epoch-averaged losses are logged.
        """
        step_key, epoch_key = jax.random.split(key)

        # Carries at the start of the rollout seed the training sequences.
        initial_actor_carry = state.actor_carry
        initial_critic_carry = state.critic_carry

        step_keys = jax.random.split(step_key, self.cfg.num_steps)
        state, transitions = jax.lax.scan(
            partial(self._step, policy=self._stochastic_action),
            state,
            step_keys,
        )

        # Bootstrap value of the timestep following the rollout.
        _, (value, _) = self.critic_network.apply(
            state.critic_params,
            *state.timestep.to_sequence(),
            initial_carry=state.critic_carry,
        )
        value = remove_time_axis(value)
        value = remove_feature_axis(value)

        _, advantages = jax.lax.scan(
            self._generalized_advantage_estimation,
            (jnp.zeros_like(value), value),
            transitions,
            reverse=True,
            unroll=16,
        )
        returns = advantages + transitions.aux["value"]
        transitions = transitions.replace(
            aux={**transitions.aux, "advantages": advantages, "returns": returns}
        )
        # (num_steps, num_envs, ...) -> (num_envs, num_steps, ...).
        transitions = jax.tree.map(lambda x: jnp.swapaxes(x, 0, 1), transitions)

        epoch_keys = jax.random.split(epoch_key, self.cfg.update_epochs)
        (state, *_), metrics = jax.lax.scan(
            self._update_epoch,
            (
                state,
                initial_actor_carry,
                initial_critic_carry,
                transitions,
            ),
            epoch_keys,
        )

        actor_loss, critic_loss, entropy, approximate_kl, clip_fraction = jax.tree.map(
            lambda x: x.mean(), metrics
        )
        lox.log(
            {
                "actor/loss": actor_loss,
                "critic/loss": critic_loss,
                "actor/entropy": entropy,
                "actor/approximate_kl": approximate_kl,
                "actor/clip_fraction": clip_fraction,
            }
        )
        return state.replace(update_step=state.update_step + 1), None

    def _reset_envs(self, key: Key) -> tuple[Timestep, EnvState]:
        """Reset all ``cfg.num_envs`` environments and build the first timestep.

        The returned timestep pairs the reset observation with a zero action,
        zero reward and ``done=True`` for every environment.
        """
        reset_keys = jax.random.split(key, self.cfg.num_envs)
        obs, env_state = jax.vmap(self.env.reset, in_axes=(0, None))(
            reset_keys, self.env_params
        )
        action_space = self.env.action_space(self.env_params)
        action = jnp.zeros(
            (self.cfg.num_envs, *action_space.shape),
            dtype=action_space.dtype,
        )
        reward = jnp.zeros((self.cfg.num_envs,), dtype=jnp.float32)
        done = jnp.ones((self.cfg.num_envs,), dtype=jnp.bool_)
        timestep = Timestep(obs=obs, action=action, reward=reward, done=done)
        return timestep, env_state

    def init(self, key: Key) -> PPOState:
        """Create the initial training state.

        Resets the environments, initializes both carries, network parameters
        (with "params", "torso" and "dropout" RNG streams) and optimizer
        states. Counters start at zero.
        """
        (
            env_key,
            actor_key,
            actor_torso_key,
            actor_dropout_key,
            critic_key,
            critic_torso_key,
            critic_dropout_key,
        ) = jax.random.split(key, 7)

        timestep, env_state = self._reset_envs(env_key)
        sequence = timestep.to_sequence()

        actor_carry = self.actor_network.initialize_carry((self.cfg.num_envs, None))
        critic_carry = self.critic_network.initialize_carry((self.cfg.num_envs, None))

        actor_params = self.actor_network.init(
            {
                "params": actor_key,
                "torso": actor_torso_key,
                "dropout": actor_dropout_key,
            },
            *sequence,
            initial_carry=actor_carry,
        )
        critic_params = self.critic_network.init(
            {
                "params": critic_key,
                "torso": critic_torso_key,
                "dropout": critic_dropout_key,
            },
            *sequence,
            initial_carry=critic_carry,
        )

        actor_optimizer_state = self.actor_optimizer.init(actor_params)
        critic_optimizer_state = self.critic_optimizer.init(critic_params)

        return PPOState(
            step=0,
            update_step=0,
            timestep=sequence.from_sequence(),
            actor_carry=actor_carry,
            critic_carry=critic_carry,
            env_state=env_state,
            actor_params=actor_params,
            critic_params=critic_params,
            actor_optimizer_state=actor_optimizer_state,
            critic_optimizer_state=critic_optimizer_state,
        )

    def warmup(self, key: Key, state: PPOState, num_steps: int) -> PPOState:
        """No-op: PPO needs no warm-up phase. Returns ``state`` unchanged."""
        return state

    def train(self, key: Key, state: PPOState, num_steps: int) -> PPOState:
        """Run ``num_steps // cfg.batch_size`` update steps.

        ``num_steps`` counts environment steps across all environments. Any
        remainder smaller than one rollout (``cfg.batch_size``) is dropped, so
        ``num_steps < cfg.batch_size`` performs no update.
        """
        keys = jax.random.split(key, num_steps // self.cfg.batch_size)
        state, _ = jax.lax.scan(
            self._update_step,
            state,
            keys,
        )
        return state

    def evaluate(self, key: Key, state: PPOState, num_steps: int) -> PPOState:
        """Roll out the greedy policy for ``num_steps`` steps per environment.

        Environments and carries are reset first. Per-step info is emitted via
        ``lox.log`` inside ``_step``. Returns the state after the rollout;
        parameters and optimizer states are not modified.
        """
        reset_key, eval_key = jax.random.split(key)
        timestep, env_state = self._reset_envs(reset_key)

        state = state.replace(
            timestep=timestep,
            actor_carry=self.actor_network.initialize_carry((self.cfg.num_envs, None)),
            critic_carry=self.critic_network.initialize_carry(
                (self.cfg.num_envs, None)
            ),
            env_state=env_state,
        )

        step_keys = jax.random.split(eval_key, num_steps)
        state, _ = jax.lax.scan(
            partial(self._step, policy=self._deterministic_action),
            state,
            step_keys,
        )
        return state