Source code for memorax.algorithms.sac

from functools import partial
from typing import Any, Callable

import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
from flax import core, struct

from memorax.utils import Transition, periodic_incremental_update
from memorax.utils.typing import (Array, Buffer, BufferState, Environment,
                                  EnvParams, EnvState, Key)


@struct.dataclass
class Batch:
    """Batch of transitions sampled from the replay buffer.

    NOTE(review): the update methods in this module compute
    ``jnp.cumsum(batch.prev_done, axis=1)``, so the fields need a second axis
    after the batch axis. The shapes below assume that axis is a
    time/sequence axis of length ``seq_len`` -- TODO confirm against the
    buffer implementation.
    """

    obs: Array
    """Observations, presumably [batch_size, seq_len, obs_dim]"""
    prev_done: Array
    """Done flags of the previous step, presumably [batch_size, seq_len]"""
    action: Array
    """Actions, presumably [batch_size, seq_len, action_dim]"""
    reward: Array
    """Rewards, presumably [batch_size, seq_len]"""
    next_obs: Array
    """Next observations, presumably [batch_size, seq_len, obs_dim]"""
    done: Array
    """Done flags of the current step, presumably [batch_size, seq_len]"""


@struct.dataclass(frozen=True)
class SACConfig:
    """Configuration for SAC algorithm.

    Fields marked "not read in this module" are not referenced by the code
    in this file; they are presumably consumed by whatever constructs the
    optimizers, networks and replay buffer -- verify against the caller.
    """

    # Identifier of the configuration (not read in this module).
    name: str
    # Learning rates (not read in this module; the optimizers are passed to
    # ``SAC`` already constructed).
    actor_lr: float
    critic_lr: float
    alpha_lr: float
    # Number of environments stepped in parallel during warmup/training.
    num_envs: int
    # Number of environments reset and stepped in parallel by ``evaluate``.
    num_eval_envs: int
    # Replay buffer capacity (not read in this module).
    buffer_size: int
    # Discount factor applied to the bootstrapped next-state value.
    gamma: float
    # Forwarded, together with ``target_update_frequency``, to
    # ``periodic_incremental_update`` when refreshing the target critic.
    tau: float
    # Environment steps collected between two gradient updates; each update
    # step scans ``train_frequency // num_envs`` vectorised env steps.
    train_frequency: int
    target_update_frequency: int
    # Replay batch size (not read in this module).
    batch_size: int
    # Initial temperature (not read in this module).
    initial_alpha: float
    # Target entropy is ``-target_entropy_scale * action_dim``.
    target_entropy_scale: float
    # Not read in this module; presumably the number of warmup steps.
    learning_starts: int
    # Not read in this module; presumably a gradient-clipping threshold.
    max_grad_norm: float
    # When True, actor/critic losses are averaged only over the timesteps up
    # to and including the first one whose ``prev_done`` flag is set.
    mask: bool
@struct.dataclass(frozen=True)
class SACState:
    """Complete training state carried between SAC steps."""

    # Environment-step counter; each rollout step adds ``cfg.num_envs``.
    step: int
    # Batched environment state returned by ``env.reset`` / ``env.step``.
    env_state: EnvState
    # Replay buffer contents.
    buffer_state: BufferState
    # Network parameters.
    actor_params: core.FrozenDict[str, Any]
    critic_params: core.FrozenDict[str, Any]
    critic_target_params: core.FrozenDict[str, Any]
    alpha_params: core.FrozenDict[str, Any]
    # Optimizer states matching the parameter sets above.
    actor_optimizer_state: optax.OptState
    critic_optimizer_state: optax.OptState
    alpha_optimizer_state: optax.OptState
    # Most recent observation and done flag of every environment.
    obs: Array
    done: Array
    # Carry of the actor network, fed back in as ``initial_carry`` on the
    # next policy call.
    actor_hidden_state: Array
@struct.dataclass(frozen=True)
class SAC:
    """Soft Actor-Critic agent operating on a vectorised environment.

    The instance bundles the environment, the actor/critic/temperature
    networks, their optimizers and the replay buffer. The jitted methods
    treat ``self`` as a static argument.
    """

    cfg: SACConfig
    env: Environment
    env_params: EnvParams
    actor_network: nn.Module
    critic_network: nn.Module
    alpha_network: nn.Module
    actor_optimizer: optax.GradientTransformation
    critic_optimizer: optax.GradientTransformation
    alpha_optimizer: optax.GradientTransformation
    buffer: Buffer

    def _deterministic_action(self, key, state: SACState):
        """Policy used for evaluation: applies the actor with ``temperature=0.0``.

        Returns ``(key, (action, next_hidden_state))``.
        """
        key, sample_key = jax.random.split(key)
        # The actor is fed a singleton axis at position 1, removed again below.
        next_hidden_state, dist = self.actor_network.apply(
            state.actor_params,
            jnp.expand_dims(state.obs, 1),
            mask=jnp.expand_dims(state.done, 1),
            initial_carry=state.actor_hidden_state,
            temperature=0.0,
        )
        action = dist.sample(seed=sample_key).squeeze(1)
        return key, (action, next_hidden_state)

    def _stochastic_action(self, key, state: SACState):
        """Policy used for training: samples an action from the actor.

        Returns ``(key, (action, next_hidden_state))``.
        """
        key, sample_key = jax.random.split(key)
        next_hidden_state, dist = self.actor_network.apply(
            state.actor_params,
            jnp.expand_dims(state.obs, 1),
            mask=jnp.expand_dims(state.done, 1),
            initial_carry=state.actor_hidden_state,
        )
        action = dist.sample(seed=sample_key).squeeze(1)
        return key, (action, next_hidden_state)

    def _random_action(self, key, state: SACState):
        """Policy used for warmup: samples from the action space.

        The actor is not evaluated, so its hidden state is returned unchanged.
        """
        key, action_key = jax.random.split(key)
        action_keys = jax.random.split(action_key, self.cfg.num_envs)
        action = jax.vmap(self.env.action_space(self.env_params).sample)(action_keys)
        return key, (action, state.actor_hidden_state)

    def _step(
        self,
        carry,
        _,
        *,
        policy: Callable[[Key, "SACState"], tuple[Key, tuple[Array, Array]]],
        write_to_buffer: bool = True,
    ):
        """Advance every environment by one step (``jax.lax.scan`` body).

        Args:
            carry: ``(key, state)`` tuple.
            _: unused scan input.
            policy: callable mapping ``(key, state)`` to
                ``(key, (action, next_actor_hidden_state))``.
            write_to_buffer: when True the transition is added to the replay
                buffer (after gaining a singleton axis at position 1).

        Returns:
            ``((key, state), transition)``. When ``write_to_buffer`` is True
            the returned transition carries the extra singleton axis.
        """
        key, state = carry

        key, policy_key, env_key = jax.random.split(key, 3)
        # The policy hands back the key that is carried forward.
        key, (action, next_actor_hidden_state) = policy(policy_key, state)

        num_envs = state.obs.shape[0]
        env_keys = jax.random.split(env_key, num_envs)
        next_obs, env_state, reward, done, info = jax.vmap(
            self.env.step, in_axes=(0, 0, 0, None)
        )(env_keys, state.env_state, action, self.env_params)

        transition = Transition(
            obs=state.obs,  # type: ignore
            prev_done=state.done,  # type: ignore
            action=action,  # type: ignore
            reward=reward,  # type: ignore
            next_obs=next_obs,  # type: ignore
            done=done,  # type: ignore
            info=info,  # type: ignore
        )

        buffer_state = state.buffer_state
        if write_to_buffer:
            transition = jax.tree.map(lambda x: jnp.expand_dims(x, 1), transition)
            buffer_state = self.buffer.add(buffer_state, transition)

        state = state.replace(
            step=state.step + self.cfg.num_envs,
            obs=next_obs,
            done=done,
            env_state=env_state,
            buffer_state=buffer_state,
            actor_hidden_state=next_actor_hidden_state,
        )
        return (key, state), transition

    @partial(jax.jit, static_argnames=["self"])
    def init(self, key):
        """Reset the environments and build the initial ``SACState``.

        Args:
            key: PRNG key.

        Returns:
            ``(key, state)`` with freshly initialised parameters, optimizer
            states and replay buffer.
        """
        key, env_key, actor_key, critic_key, alpha_key = jax.random.split(key, 5)

        env_keys = jax.random.split(env_key, self.cfg.num_envs)
        obs, env_state = jax.vmap(self.env.reset, in_axes=(0, None))(
            env_keys, self.env_params
        )

        # Dummy action/step that provides example reward/done/info values for
        # network and buffer initialisation.
        # NOTE(review): the same ``env_keys`` are reused for reset, action
        # sampling and this step; kept as-is to preserve the RNG stream.
        action = jax.vmap(self.env.action_space(self.env_params).sample)(env_keys)
        _, _, reward, done, info = jax.vmap(self.env.step, in_axes=(0, 0, 0, None))(
            env_keys, env_state, action, self.env_params
        )

        actor_carry = self.actor_network.initialize_carry(obs.shape)
        actor_params = self.actor_network.init(
            actor_key,
            observation=jnp.expand_dims(obs, 1),
            mask=jnp.expand_dims(done, 1),
            initial_carry=actor_carry,
        )
        actor_optimizer_state = self.actor_optimizer.init(actor_params)
        actor_hidden_state = self.actor_network.initialize_carry(obs.shape)

        critic_carry = self.critic_network.initialize_carry(obs.shape)
        critic_params = self.critic_network.init(
            critic_key,
            jnp.expand_dims(obs, 1),
            jnp.expand_dims(done, 1),
            initial_carry=critic_carry,
            action=action,
        )
        # Same key and inputs as above: the target critic starts out with the
        # same initialisation as the online critic.
        critic_target_params = self.critic_network.init(
            critic_key,
            jnp.expand_dims(obs, 1),
            jnp.expand_dims(done, 1),
            initial_carry=critic_carry,
            action=action,
        )
        critic_optimizer_state = self.critic_optimizer.init(critic_params)

        alpha_params = self.alpha_network.init(alpha_key)
        alpha_optimizer_state = self.alpha_optimizer.init(alpha_params)

        # Example transition; only its first entry along the leading axis is
        # handed to the buffer initialiser.
        transition = Transition(
            obs=obs,
            prev_done=done,
            action=action,
            reward=reward,
            next_obs=obs,
            done=done,
            info=info,
        )
        buffer_state = self.buffer.init(jax.tree.map(lambda x: x[0], transition))

        return key, SACState(
            step=0,
            actor_hidden_state=actor_hidden_state,
            env_state=env_state,
            buffer_state=buffer_state,
            actor_params=actor_params,
            critic_params=critic_params,
            critic_target_params=critic_target_params,
            alpha_params=alpha_params,
            actor_optimizer_state=actor_optimizer_state,
            critic_optimizer_state=critic_optimizer_state,
            alpha_optimizer_state=alpha_optimizer_state,
            obs=obs,
            done=done,
        )

    def _loss_mask(self, prev_done: Array) -> Array:
        """Boolean mask selecting the timesteps that contribute to the losses.

        With ``cfg.mask`` disabled every entry is kept. Otherwise, along
        axis 1, entries are kept up to and including the first one whose
        ``prev_done`` flag is set.
        """
        if not self.cfg.mask:
            return jnp.ones_like(prev_done, dtype=bool)
        episode_idx = jnp.cumsum(prev_done.astype(jnp.int32), axis=1)
        first_boundary = (episode_idx == 1) & prev_done
        return (episode_idx == 0) | first_boundary

    @partial(jax.jit, static_argnames=["self"])
    def _update_alpha(self, key, state: SACState):
        """Take one optimizer step on the temperature parameters.

        Returns ``(state, info)`` where ``info`` holds ``"alpha"`` and
        ``"alpha_loss"``.
        """
        action_dim = self.env.action_space(self.env_params).shape[0]
        target_entropy = -self.cfg.target_entropy_scale * action_dim

        # NOTE(review): unlike the rollout policies, the actor is applied here
        # to ``state.obs``/``state.done`` without the extra axis and without
        # the stored carry -- confirm this is intended.
        _, dist = self.actor_network.apply(state.actor_params, state.obs, state.done)
        key, sample_key = jax.random.split(key)
        actions = dist.sample(seed=sample_key)
        # Entropy estimate is computed outside the loss, so no gradient flows
        # into the actor from the temperature update.
        entropy = (-dist.log_prob(actions)).mean()

        def alpha_loss_fn(alpha_params):
            alpha = self.alpha_network.apply(alpha_params)
            alpha_loss = alpha * (entropy - target_entropy).mean()
            return alpha_loss, {"alpha": alpha, "alpha_loss": alpha_loss}

        (_, info), grads = jax.value_and_grad(alpha_loss_fn, has_aux=True)(
            state.alpha_params
        )
        updates, optimizer_state = self.alpha_optimizer.update(
            grads, state.alpha_optimizer_state, state.alpha_params
        )
        alpha_params = optax.apply_updates(state.alpha_params, updates)

        state = state.replace(
            alpha_params=alpha_params, alpha_optimizer_state=optimizer_state
        )
        return state, info

    @partial(jax.jit, static_argnames=["self"])
    def _update_actor(self, key, state: SACState, batch: Batch):
        """Take one optimizer step on the actor.

        The loss is the masked mean of ``alpha * log_prob - min(q1, q2)`` for
        actions sampled from the current policy.

        Returns ``(state, info)`` where ``info`` holds ``"actor_loss"`` and
        ``"entropy"``.
        """
        alpha = self.alpha_network.apply(state.alpha_params)
        mask = self._loss_mask(batch.prev_done)

        def actor_loss_fn(actor_params):
            _, dist = self.actor_network.apply(actor_params, batch.obs, batch.prev_done)
            actions, log_probs = dist.sample_and_log_prob(seed=key)
            _, (q1, q2) = self.critic_network.apply(
                state.critic_params, batch.obs, batch.prev_done, action=actions
            )
            q = jnp.minimum(q1, q2)
            actor_loss = (log_probs * alpha - q.squeeze(-1)).mean(where=mask)
            return actor_loss, {"actor_loss": actor_loss, "entropy": -log_probs.mean()}

        (_, info), grads = jax.value_and_grad(actor_loss_fn, has_aux=True)(
            state.actor_params
        )
        updates, actor_optimizer_state = self.actor_optimizer.update(
            grads, state.actor_optimizer_state, state.actor_params
        )
        actor_params = optax.apply_updates(state.actor_params, updates)

        state = state.replace(
            actor_params=actor_params, actor_optimizer_state=actor_optimizer_state
        )
        return state, info

    @partial(jax.jit, static_argnames=["self"])
    def _update_critic(self, key, state: SACState, batch: Batch):
        """Take one optimizer step on the critics and refresh the target critic.

        The regression target is the soft Bellman backup
        ``r + gamma * (1 - done) * (min(q1', q2') - alpha * log_prob(a'))``
        with ``a'`` sampled from the current policy at ``next_obs`` and the
        Q-values taken from the target critic.

        Returns ``(state, info)`` where ``info`` holds ``"critic_loss"``,
        ``"q1"`` and ``"q2"``.
        """
        _, dist = self.actor_network.apply(
            state.actor_params,
            batch.next_obs,
            batch.done,
        )
        next_actions, next_log_probs = dist.sample_and_log_prob(seed=key)

        _, (next_q1, next_q2) = self.critic_network.apply(
            state.critic_target_params, batch.next_obs, batch.done, action=next_actions
        )
        next_q = jnp.minimum(next_q1, next_q2).squeeze(-1)

        # Entropy bonus of the soft value. ``alpha`` and ``next_log_probs``
        # were previously computed but never used, which reduced the target
        # to a plain (non-entropy-regularised) Bellman backup.
        alpha = self.alpha_network.apply(state.alpha_params)
        soft_next_q = next_q - alpha * next_log_probs

        target_q = batch.reward + self.cfg.gamma * (1 - batch.done) * soft_next_q
        target_q = jax.lax.stop_gradient(target_q)

        mask = self._loss_mask(batch.prev_done)

        def critic_loss_fn(critic_params):
            _, (q1, q2) = self.critic_network.apply(
                critic_params, batch.obs, batch.prev_done, action=batch.action
            )
            q1_error = q1.squeeze(-1) - target_q
            q2_error = q2.squeeze(-1) - target_q
            critic_loss = (q1_error**2 + q2_error**2).mean(where=mask)
            return critic_loss, {
                "critic_loss": critic_loss,
                "q1": q1.mean(),
                "q2": q2.mean(),
            }

        (_, info), grads = jax.value_and_grad(critic_loss_fn, has_aux=True)(
            state.critic_params
        )
        updates, critic_optimizer_state = self.critic_optimizer.update(
            grads, state.critic_optimizer_state, state.critic_params
        )
        critic_params = optax.apply_updates(state.critic_params, updates)

        critic_target_params = periodic_incremental_update(
            critic_params,
            state.critic_target_params,
            state.step,
            self.cfg.target_update_frequency,
            self.cfg.tau,
        )

        state = state.replace(
            critic_params=critic_params,
            critic_target_params=critic_target_params,
            critic_optimizer_state=critic_optimizer_state,
        )
        return state, info

    def _update(self, key, state: SACState):
        """Sample one replay batch and update critic, actor and temperature.

        Returns ``(state, info)`` with the three info dicts merged.
        """
        key, batch_key, critic_key, actor_key, alpha_key = jax.random.split(key, 5)

        batch = self.buffer.sample(state.buffer_state, batch_key).experience

        # Order matters: the actor update uses the freshly updated critic and
        # the temperature update uses the freshly updated actor.
        state, critic_info = self._update_critic(critic_key, state, batch)
        state, actor_info = self._update_actor(actor_key, state, batch)
        state, alpha_info = self._update_alpha(alpha_key, state)

        info = {**critic_info, **actor_info, **alpha_info}
        return state, info

    def _update_step(self, carry, _):
        """Collect ``train_frequency`` env steps, then run one update.

        Scan body used by ``train``. The update metrics are merged into the
        ``info`` mapping of the collected transitions.
        """
        (key, state), transitions = jax.lax.scan(
            partial(self._step, policy=self._stochastic_action),
            carry,
            length=self.cfg.train_frequency // self.cfg.num_envs,
        )

        key, update_key = jax.random.split(key)
        state, update_info = self._update(update_key, state)

        transitions.info.update(update_info)
        return (key, state), transitions

    @partial(jax.jit, static_argnames=["self", "num_steps"])
    def warmup(self, key, state: SACState, num_steps: int) -> tuple[Key, SACState]:
        """Fill the replay buffer using randomly sampled actions.

        Args:
            key: PRNG key.
            state: current training state.
            num_steps: total environment steps; ``num_steps // num_envs``
                vectorised steps are executed.

        Returns:
            ``(key, state)`` after the rollout.
        """
        (key, state), _ = jax.lax.scan(
            partial(self._step, policy=self._random_action),
            (key, state),
            length=num_steps // self.cfg.num_envs,
        )
        return key, state

    @partial(jax.jit, static_argnames=["self", "num_steps"])
    def train(self, key: Key, state: SACState, num_steps: int):
        """Alternate data collection and gradient updates.

        Args:
            key: PRNG key.
            state: current training state.
            num_steps: total environment steps;
                ``num_steps // train_frequency`` collect-then-update
                iterations are executed.

        Returns:
            ``(key, state, info)`` where ``info`` stacks the transitions
            (including update metrics) of every iteration.
        """
        (key, state), info = jax.lax.scan(
            self._update_step,
            (key, state),
            length=(num_steps // self.cfg.train_frequency),
        )
        return key, state, info

    @partial(jax.jit, static_argnames=["self", "num_steps"])
    def evaluate(self, key: Key, state: SACState, num_steps: int):
        """Roll out the deterministic policy in freshly reset environments.

        Nothing is written to the replay buffer and the training state passed
        in is not returned.

        Args:
            key: PRNG key.
            state: training state providing the actor parameters.
            num_steps: total environment steps; ``num_steps // num_eval_envs``
                vectorised steps are executed.

        Returns:
            ``(key, transitions)`` with the stacked evaluation transitions.
        """
        key, env_key = jax.random.split(key)
        env_keys = jax.random.split(env_key, self.cfg.num_eval_envs)

        eval_obs, eval_env_state = jax.vmap(self.env.reset, in_axes=(0, None))(
            env_keys, self.env_params
        )
        eval_done = jnp.zeros((self.cfg.num_eval_envs,), dtype=jnp.bool_)
        eval_hidden_state = self.actor_network.initialize_carry(eval_obs.shape)

        # Swap only the rollout-related fields; parameters are reused as-is.
        eval_state = state.replace(
            obs=eval_obs,
            done=eval_done,
            env_state=eval_env_state,
            actor_hidden_state=eval_hidden_state,
        )

        (key, eval_state), transitions = jax.lax.scan(
            partial(
                self._step,
                policy=self._deterministic_action,
                write_to_buffer=False,
            ),
            (key, eval_state),
            length=num_steps // self.cfg.num_eval_envs,
        )
        return key, transitions