from dataclasses import dataclass
from functools import partial
from typing import Any, Callable
import flax.linen as nn
import jax
import jax.numpy as jnp
import lox
import optax
from flax import core, struct
from memorax.utils import Timestep, Transition, periodic_incremental_update, utils
from memorax.utils.axes import remove_time_axis
from memorax.utils.typing import (
Array,
Buffer,
BufferState,
Carry,
Environment,
EnvParams,
EnvState,
Key,
PyTree,
)
[docs]
@struct.dataclass(frozen=True)
class SACConfig:
num_envs: int
gamma: float
tau: float
train_frequency: int
target_update_frequency: int
target_entropy_scale: float
gradient_steps: int = 1
burn_in_length: int = 0
[docs]
@struct.dataclass(frozen=True)
class SACState:
step: int
update_step: int
timestep: Timestep
env_state: EnvState
buffer_state: BufferState
actor_params: core.FrozenDict[str, Any]
critic_params: core.FrozenDict[str, Any]
critic_target_params: core.FrozenDict[str, Any]
alpha_params: core.FrozenDict[str, Any]
actor_optimizer_state: optax.OptState
critic_optimizer_state: optax.OptState
alpha_optimizer_state: optax.OptState
actor_carry: Array
critic_carry: Array
[docs]
@dataclass
class SAC:
cfg: SACConfig
env: Environment
env_params: EnvParams
actor_network: nn.Module
critic_network: nn.Module
alpha_network: nn.Module
actor_optimizer: optax.GradientTransformation
critic_optimizer: optax.GradientTransformation
alpha_optimizer: optax.GradientTransformation
buffer: Buffer
def __post_init__(self):
assert (
self.cfg.train_frequency >= self.cfg.num_envs
), f"train_frequency ({self.cfg.train_frequency}) must be >= num_envs ({self.cfg.num_envs})"
assert (
self.cfg.train_frequency % self.cfg.num_envs == 0
), f"train_frequency ({self.cfg.train_frequency}) must be divisible by num_envs ({self.cfg.num_envs})"
assert (
self.cfg.gradient_steps >= 1
), f"gradient_steps ({self.cfg.gradient_steps}) must be >= 1"
def _deterministic_action(self, key: Key, state: SACState):
sample_key = key
obs, done, action, reward = state.timestep.to_sequence()
(next_carry, (dist, _)), intermediates = self.actor_network.apply(
state.actor_params,
observation=obs,
done=done,
action=action,
reward=reward,
initial_carry=state.actor_carry,
temperature=0.0,
mutable=["intermediates"],
)
action = dist.sample(seed=sample_key)
action = remove_time_axis(action)
state = state.replace(actor_carry=next_carry)
return state, action, intermediates
def _stochastic_action(self, key: Key, state: SACState):
sample_key = key
obs, done, action, reward = state.timestep.to_sequence()
(next_carry, (dist, _)), intermediates = self.actor_network.apply(
state.actor_params,
observation=obs,
done=done,
action=action,
reward=reward,
initial_carry=state.actor_carry,
mutable=["intermediates"],
)
action = dist.sample(seed=sample_key)
action = remove_time_axis(action)
state = state.replace(actor_carry=next_carry)
return state, action, intermediates
def _random_action(self, key: Key, state: SACState):
action_keys = jax.random.split(key, self.cfg.num_envs)
action = jax.vmap(self.env.action_space(self.env_params).sample)(action_keys)
return state, action, {}
def _step(
self,
state: SACState,
key: Key,
*,
policy: Callable,
):
initial_carry = state.actor_carry
action_key, step_key = jax.random.split(key)
state, action, intermediates = policy(action_key, state)
num_envs, *_ = state.timestep.obs.shape
step_key = jax.random.split(step_key, num_envs)
next_obs, env_state, reward, done, info = jax.vmap(
self.env.step, in_axes=(0, 0, 0, None)
)(step_key, state.env_state, action, self.env_params)
intermediates = jax.tree.map(
lambda x: jnp.mean(jnp.stack(x)),
intermediates.get("intermediates", {}),
is_leaf=lambda x: isinstance(x, tuple),
)
broadcast_dims = tuple(
range(state.timestep.done.ndim, state.timestep.action.ndim)
)
first = Timestep(
obs=state.timestep.obs,
action=state.timestep.action,
reward=state.timestep.reward,
done=state.timestep.done,
)
second = Timestep(
obs=next_obs,
action=action,
reward=reward,
done=done,
)
lox.log({"info": info, "intermediates": intermediates})
transition = Transition(
first=first,
second=second,
carry=initial_carry,
)
buffer_transition = jax.tree.map(lambda x: jnp.expand_dims(x, 1), transition)
buffer_state = self.buffer.add(state.buffer_state, buffer_transition)
next_reward = jnp.asarray(reward, dtype=jnp.float32)
state = state.replace(
step=state.step + self.cfg.num_envs,
timestep=Timestep(
obs=next_obs,
action=jnp.where(
jnp.expand_dims(done, axis=broadcast_dims),
jnp.zeros_like(action),
action,
),
reward=jnp.where(done, jnp.zeros_like(next_reward), next_reward),
done=done,
),
env_state=env_state,
buffer_state=buffer_state,
)
return state, transition
[docs]
def init(self, key: Key):
env_key, actor_key, actor_torso_key, critic_key, critic_torso_key, alpha_key = (
jax.random.split(key, 6)
)
env_keys = jax.random.split(env_key, self.cfg.num_envs)
obs, env_state = jax.vmap(self.env.reset, in_axes=(0, None))(
env_keys, self.env_params
)
action_space = self.env.action_space(self.env_params)
action = jnp.zeros(
(self.cfg.num_envs, *action_space.shape), dtype=action_space.dtype
)
reward = jnp.zeros((self.cfg.num_envs,), dtype=jnp.float32)
done = jnp.ones((self.cfg.num_envs,), dtype=jnp.bool_)
timestep = Timestep(
obs=obs, action=action, reward=reward, done=done
).to_sequence()
actor_carry = self.actor_network.initialize_carry((self.cfg.num_envs, None))
actor_params = self.actor_network.init(
{"params": actor_key, "torso": actor_torso_key},
*timestep,
initial_carry=actor_carry,
)
actor_optimizer_state = self.actor_optimizer.init(actor_params)
critic_carry = self.critic_network.initialize_carry((self.cfg.num_envs, None))
critic_params = self.critic_network.init(
{"params": critic_key, "torso": critic_torso_key},
*timestep,
initial_carry=critic_carry,
)
critic_target_params = critic_params
critic_optimizer_state = self.critic_optimizer.init(critic_params)
alpha_params = self.alpha_network.init(alpha_key)
alpha_optimizer_state = self.alpha_optimizer.init(alpha_params)
timestep = timestep.from_sequence()
transition = Transition(
first=timestep,
second=timestep,
carry=actor_carry,
)
buffer_state = self.buffer.init(jax.tree.map(lambda x: x[0], transition))
return SACState(
step=0,
update_step=0,
timestep=timestep,
actor_carry=actor_carry,
critic_carry=critic_carry,
env_state=env_state,
buffer_state=buffer_state,
actor_params=actor_params,
critic_params=critic_params,
critic_target_params=critic_target_params,
alpha_params=alpha_params,
actor_optimizer_state=actor_optimizer_state,
critic_optimizer_state=critic_optimizer_state,
alpha_optimizer_state=alpha_optimizer_state,
)
def _update_alpha(
self,
key: Key,
state: SACState,
experience: PyTree,
initial_actor_carry: Carry = None,
):
action_dim, *_ = self.env.action_space(self.env_params).shape
target_entropy = -self.cfg.target_entropy_scale * action_dim
first_obs, first_done, first_action, first_reward = experience.first
_, (dist, _) = self.actor_network.apply(
state.actor_params,
observation=first_obs,
done=first_done,
action=first_action,
reward=first_reward,
initial_carry=initial_actor_carry,
)
key, sample_key = jax.random.split(key)
_, log_probs = dist.sample_and_log_prob(seed=sample_key)
def alpha_loss_fn(alpha_params: PyTree):
log_alpha = self.alpha_network.apply(alpha_params)
alpha = jnp.exp(log_alpha)
alpha_loss = (alpha * (-log_probs - target_entropy)).mean()
return alpha_loss, (alpha, alpha_loss)
(_, (alpha, alpha_loss)), grads = jax.value_and_grad(
alpha_loss_fn, has_aux=True
)(state.alpha_params)
lox.log(
{
"alpha/loss": alpha_loss,
"alpha/value": alpha,
"alpha/gradient_norm": optax.global_norm(grads),
}
)
updates, optimizer_state = self.alpha_optimizer.update(
grads, state.alpha_optimizer_state, state.alpha_params
)
alpha_params = optax.apply_updates(state.alpha_params, updates)
state = state.replace(
alpha_params=alpha_params, alpha_optimizer_state=optimizer_state
)
return state
def _update_actor(
self,
key: Key,
state: SACState,
experience: PyTree,
initial_actor_carry: Carry = None,
initial_critic_carry: Carry = None,
):
log_alpha = self.alpha_network.apply(state.alpha_params)
alpha = jnp.exp(log_alpha)
first_obs, first_done, first_action, first_reward = experience.first
def actor_loss_fn(actor_params):
carry, (dist, _) = self.actor_network.apply(
actor_params,
observation=first_obs,
done=first_done,
action=first_action,
reward=first_reward,
initial_carry=initial_actor_carry,
)
actions, log_probs = dist.sample_and_log_prob(seed=key)
_, (qs, _) = self.critic_network.apply(
state.critic_params,
observation=first_obs,
done=first_done,
action=actions,
reward=first_reward,
initial_carry=initial_critic_carry,
)
q = jnp.minimum(*qs)
actor_loss = (log_probs * alpha - q).mean()
return actor_loss, (carry, actor_loss, log_probs)
(_, (carry, actor_loss, log_probs)), grads = jax.value_and_grad(
actor_loss_fn, has_aux=True
)(state.actor_params)
lox.log(
{
"actor/loss": actor_loss,
"actor/entropy": -log_probs.mean(),
"actor/gradient_norm": optax.global_norm(grads),
}
)
updates, actor_optimizer_state = self.actor_optimizer.update(
grads, state.actor_optimizer_state, state.actor_params
)
actor_params = optax.apply_updates(state.actor_params, updates)
state = state.replace(
actor_params=actor_params, actor_optimizer_state=actor_optimizer_state
)
return state, carry
def _update_critic(
self,
key: Key,
state: SACState,
experience: PyTree,
initial_actor_carry: Carry = None,
initial_critic_carry: Carry = None,
initial_target_critic_carry: Carry = None,
):
second_obs, second_done, second_action, second_reward = experience.second
_, (dist, _) = self.actor_network.apply(
state.actor_params,
observation=second_obs,
done=second_done,
action=second_action,
reward=second_reward,
initial_carry=initial_actor_carry,
)
next_actions, next_log_probs = dist.sample_and_log_prob(seed=key)
_, (next_qs, _) = self.critic_network.apply(
state.critic_target_params,
observation=second_obs,
done=second_done,
action=next_actions,
reward=second_reward,
initial_carry=initial_target_critic_carry,
)
next_q = jnp.minimum(*next_qs)
log_alpha = self.alpha_network.apply(state.alpha_params)
alpha = jnp.exp(log_alpha)
next_value = next_q - alpha * next_log_probs
target_q = (
experience.second.reward
+ self.cfg.gamma * (1 - experience.second.done) * next_value
)
target_q = jax.lax.stop_gradient(target_q)
first_obs, first_done, first_action, first_reward = experience.first
def critic_loss_fn(critic_params):
_, (qs, _) = self.critic_network.apply(
critic_params,
observation=first_obs,
done=first_done,
action=second_action,
reward=first_reward,
initial_carry=initial_critic_carry,
)
q1, q2 = qs
critic_loss = (
self.critic_network.head.loss(
q1, {}, target_q, transitions=experience
).mean()
+ self.critic_network.head.loss(
q2, {}, target_q, transitions=experience
).mean()
)
return critic_loss, (critic_loss, q1, q2)
(_, (critic_loss, q1, q2)), grads = jax.value_and_grad(
critic_loss_fn, has_aux=True
)(state.critic_params)
lox.log(
{
"critic/loss": critic_loss,
"critic/q1": q1.mean(),
"critic/q2": q2.mean(),
"critic/gradient_norm": optax.global_norm(grads),
}
)
updates, critic_optimizer_state = self.critic_optimizer.update(
grads, state.critic_optimizer_state, state.critic_params
)
critic_params = optax.apply_updates(state.critic_params, updates)
critic_target_params = periodic_incremental_update(
critic_params,
state.critic_target_params,
state.step,
self.cfg.target_update_frequency,
self.cfg.tau,
)
state = state.replace(
critic_params=critic_params,
critic_target_params=critic_target_params,
critic_optimizer_state=critic_optimizer_state,
)
return state
def _update(self, key: Key, state: SACState):
batch_key, critic_key, actor_key, alpha_key = jax.random.split(key, 4)
experience = self.buffer.sample(state.buffer_state, batch_key).experience
experience = jax.tree.map(lambda x: jnp.expand_dims(x, 1), experience)
initial_actor_carry = None
initial_critic_carry = None
initial_target_critic_carry = None
if experience.carry is not None:
initial_actor_carry = jax.tree.map(lambda x: x[:, 0], experience.carry)
initial_actor_carry = utils.burn_in(
self.actor_network,
state.actor_params,
experience.first,
initial_actor_carry,
self.cfg.burn_in_length,
)
initial_critic_carry = utils.burn_in(
self.critic_network,
state.critic_params,
experience.first,
initial_critic_carry,
self.cfg.burn_in_length,
)
initial_target_critic_carry = utils.burn_in(
self.critic_network,
state.critic_target_params,
experience.second,
initial_target_critic_carry,
self.cfg.burn_in_length,
)
experience = jax.tree.map(lambda x: x[:, self.cfg.burn_in_length :], experience)
state = self._update_critic(
critic_key,
state,
experience,
initial_actor_carry,
initial_critic_carry,
initial_target_critic_carry,
)
state, actor_carry = self._update_actor(
actor_key,
state,
experience,
initial_actor_carry,
initial_critic_carry,
)
state = self._update_alpha(alpha_key, state, experience, initial_actor_carry)
return state
def _update_step(self, state: SACState, key: Key):
step_key, gradient_key = jax.random.split(key)
step_keys = jax.random.split(
step_key, self.cfg.train_frequency // self.cfg.num_envs
)
state, transitions = jax.lax.scan(
partial(self._step, policy=self._stochastic_action),
state,
step_keys,
)
gradient_keys = jax.random.split(gradient_key, self.cfg.gradient_steps)
state, _ = jax.lax.scan(
lambda state, key: (self._update(key, state), None),
state,
gradient_keys,
)
return state.replace(update_step=state.update_step + 1), None
[docs]
def warmup(self, key: Key, state: SACState, num_steps: int) -> SACState:
step_keys = jax.random.split(key, num_steps // self.cfg.num_envs)
state, _ = jax.lax.scan(
partial(self._step, policy=self._random_action),
state,
step_keys,
)
return state
[docs]
def train(self, key: Key, state: SACState, num_steps: int):
keys = jax.random.split(key, num_steps // self.cfg.train_frequency)
state, _ = jax.lax.scan(
self._update_step,
state,
keys,
)
return state
[docs]
def evaluate(self, key: Key, state: SACState, num_steps: int) -> SACState:
reset_key, eval_key = jax.random.split(key)
reset_key = jax.random.split(reset_key, self.cfg.num_envs)
obs, env_state = jax.vmap(self.env.reset, in_axes=(0, None))(
reset_key, self.env_params
)
action_space = self.env.action_space(self.env_params)
action = jnp.zeros(
(self.cfg.num_envs, *action_space.shape), dtype=action_space.dtype
)
reward = jnp.zeros((self.cfg.num_envs,), dtype=jnp.float32)
done = jnp.ones((self.cfg.num_envs,), dtype=jnp.bool_)
timestep = Timestep(obs=obs, action=action, reward=reward, done=done)
carry = self.actor_network.initialize_carry((self.cfg.num_envs, None))
state = state.replace(
timestep=timestep,
env_state=env_state,
actor_carry=carry,
)
step_keys = jax.random.split(eval_key, num_steps)
state, _ = jax.lax.scan(
partial(
self._step,
policy=self._deterministic_action,
),
state,
step_keys,
)
return state