Source code for memorax.algorithms.pqn

from dataclasses import dataclass
from functools import partial
from typing import Any, Callable

import flax.linen as nn
import jax
import jax.numpy as jnp
import lox
import optax
from flax import core, struct

from memorax.utils import Timestep, Transition, utils
from memorax.utils.axes import add_feature_axis, remove_feature_axis, remove_time_axis
from memorax.utils.typing import Array, Environment, EnvParams, EnvState, Key, PyTree


[docs] @struct.dataclass(frozen=True) class PQNConfig: num_envs: int num_steps: int gamma: float td_lambda: float num_minibatches: int update_epochs: int burn_in_length: int = 0 @property def batch_size(self): return self.num_envs * self.num_steps
[docs] @struct.dataclass(frozen=True) class PQNState: step: int update_step: int timestep: Timestep env_state: EnvState params: core.FrozenDict[str, Any] carry: Array optimizer_state: optax.OptState
[docs] @dataclass class PQN: cfg: PQNConfig env: Environment env_params: EnvParams q_network: nn.Module optimizer: optax.GradientTransformation epsilon_schedule: optax.Schedule def __post_init__(self): assert ( self.cfg.update_epochs >= 1 ), f"update_epochs ({self.cfg.update_epochs}) must be >= 1" assert ( self.cfg.batch_size % self.cfg.num_minibatches == 0 ), f"num_envs * num_steps ({self.cfg.batch_size}) must be divisible by num_minibatches ({self.cfg.num_minibatches})" def _greedy_action( self, key: Key, state: PQNState ) -> tuple[PQNState, Array, Array, dict]: obs, done, action, reward = state.timestep.to_sequence() (carry, (q_values, _)), intermediates = self.q_network.apply( state.params, observation=obs, done=done, action=action, reward=reward, initial_carry=state.carry, mutable=["intermediates"], ) q_values = remove_time_axis(q_values) action = jnp.argmax(q_values, axis=-1) state = state.replace(carry=carry) return state, action, q_values, intermediates def _random_action( self, key: Key, state: PQNState ) -> tuple[PQNState, Array, None, dict]: action_key = jax.random.split(key, self.cfg.num_envs) action = jax.vmap(self.env.action_space(self.env_params).sample)(action_key) return state, action, None, {} def _epsilon_greedy_action( self, key: Key, state: PQNState ) -> tuple[PQNState, Array, Array, dict]: random_key, greedy_key, sample_key = jax.random.split(key, 3) state, random_action, _, _ = self._random_action(random_key, state) state, greedy_action, q_values, intermediates = self._greedy_action( greedy_key, state ) epsilon = self.epsilon_schedule(state.step) action = jnp.where( jax.random.uniform(sample_key, greedy_action.shape) < epsilon, random_action, greedy_action, ) return state, action, q_values, intermediates def _step( self, state: PQNState, key: Key, *, policy: Callable ) -> tuple[PQNState, Transition]: action_key, step_key = jax.random.split(key) state, action, q_values, intermediates = policy(action_key, state) num_envs, *_ = state.timestep.obs.shape step_key = jax.random.split(step_key, num_envs) next_obs, env_state, reward, done, info = jax.vmap( self.env.step, in_axes=(0, 0, 0, None) )(step_key, state.env_state, action, self.env_params) intermediates = jax.tree.map( lambda x: jnp.mean(jnp.stack(x)), intermediates.get("intermediates", {}), is_leaf=lambda x: isinstance(x, tuple), ) first = Timestep( obs=state.timestep.obs, action=state.timestep.action, reward=state.timestep.reward, done=state.timestep.done, ) second = Timestep( obs=None, action=action, reward=reward, done=done, ) lox.log({"info": info, "intermediates": intermediates}) transition = Transition( first=first, second=second, aux={"q_values": q_values}, ) next_reward = jnp.asarray(reward, dtype=jnp.float32) state = state.replace( step=state.step + self.cfg.num_envs, timestep=Timestep( obs=next_obs, action=jnp.where(done, jnp.zeros_like(action), action), reward=jnp.where(done, jnp.zeros_like(next_reward), next_reward), done=done, ), env_state=env_state, ) return state, transition def _td_lambda(self, carry: tuple, transition: Transition): lambda_return, next_q_value = carry target_bootstrap = ( transition.second.reward + self.cfg.gamma * (1 - transition.second.done) * next_q_value ) delta = lambda_return - next_q_value lambda_return = target_bootstrap + self.cfg.gamma * self.cfg.td_lambda * delta lambda_return = ( 1.0 - transition.second.done ) * lambda_return + transition.second.done * transition.second.reward q_value = jnp.max(transition.aux["q_values"], axis=-1) return (lambda_return, q_value), lambda_return def _update_epoch(self, carry: tuple, key: Key): state, initial_carry, transitions = carry permutation_key, minibatch_key = jax.random.split(key) batch = (initial_carry, transitions) def shuffle(batch: PyTree): shuffle_time_axis = initial_carry is None num_permutations = self.cfg.num_envs if shuffle_time_axis: batch = ( initial_carry, jax.tree.map( lambda x: x.reshape(-1, 1, *x.shape[2:]), transitions, ), ) num_permutations *= self.cfg.num_steps permutation = jax.random.permutation(permutation_key, num_permutations) minibatches = jax.tree.map( lambda x: jnp.take(x, permutation, axis=0).reshape( self.cfg.num_minibatches, -1, *x.shape[1:] ), tuple(batch), ) return minibatches minibatches = shuffle(batch) minibatch_keys = jax.random.split(minibatch_key, self.cfg.num_minibatches) state, metrics = jax.lax.scan( self._update_minibatch, state, (minibatches, minibatch_keys) ) return (state, initial_carry, transitions), metrics def _update_minibatch(self, state: PQNState, xs: tuple): minibatch, key = xs carry, transitions = minibatch torso_key, dropout_key = jax.random.split(key) carry = utils.burn_in( self.q_network, state.params, transitions.first, carry, self.cfg.burn_in_length, ) transitions = jax.tree.map( lambda x: x[:, self.cfg.burn_in_length :], transitions ) target = transitions.aux["targets"] first_obs, first_done, first_action, first_reward = transitions.first def loss_fn(params: PyTree): _, (q_values, aux) = self.q_network.apply( params, observation=first_obs, done=first_done, action=first_action, reward=first_reward, initial_carry=carry, rngs={"torso": torso_key, "dropout": dropout_key}, ) action = add_feature_axis(transitions.second.action) q_value = jnp.take_along_axis(q_values, action, axis=-1) q_value = remove_feature_axis(q_value) td_error = q_value - target loss = self.q_network.head.loss( q_value, aux, target, transitions=transitions ).mean() return loss, ( q_value.mean(), jnp.abs(td_error).mean(), ) (loss, aux), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params) lox.log({"q_network/gradient_norm": optax.global_norm(grads)}) updates, optimizer_state = self.optimizer.update( grads, state.optimizer_state, state.params ) params = optax.apply_updates(state.params, updates) state = state.replace( params=params, optimizer_state=optimizer_state, ) return state, (loss, *aux) def _update_step(self, state: PQNState, key: Key) -> tuple[PQNState, None]: step_key, torso_key, epoch_key = jax.random.split(key, 3) initial_carry = state.carry step_keys = jax.random.split(step_key, self.cfg.num_steps) state, transitions = jax.lax.scan( partial(self._step, policy=self._epsilon_greedy_action), state, step_keys, ) obs, done, action, reward = state.timestep.to_sequence() _, (q_values, _) = self.q_network.apply( state.params, observation=obs, done=done, action=action, reward=reward, initial_carry=state.carry, rngs={"torso": torso_key}, ) q_value = jnp.max(q_values, axis=-1) * (1.0 - done) q_value = remove_time_axis(q_value) _, targets = jax.lax.scan( self._td_lambda, (q_value, q_value), transitions, reverse=True, ) transitions = transitions.replace(aux={**transitions.aux, "targets": targets}) transitions = jax.tree.map(lambda x: jnp.swapaxes(x, 0, 1), transitions) epoch_keys = jax.random.split(epoch_key, self.cfg.update_epochs) (state, _, transitions), metrics = jax.lax.scan( self._update_epoch, (state, initial_carry, transitions), epoch_keys, ) loss, q_value, td_error = metrics lox.log( { "q_network/loss": loss, "q_network/td_error": td_error, "q_network/q_value": q_value, "training/epsilon": self.epsilon_schedule(state.step), } ) return state.replace(update_step=state.update_step + 1), None
[docs] def init(self, key: Key) -> PQNState: env_key, q_key, torso_key = jax.random.split(key, 3) env_keys = jax.random.split(env_key, self.cfg.num_envs) obs, env_state = jax.vmap(self.env.reset, in_axes=(0, None))( env_keys, self.env_params ) action = jnp.zeros( (self.cfg.num_envs, *self.env.action_space(self.env_params).shape), dtype=self.env.action_space(self.env_params).dtype, ) reward = jnp.zeros((self.cfg.num_envs,), dtype=jnp.float32) done = jnp.ones(self.cfg.num_envs, dtype=jnp.bool_) timestep = Timestep( obs=obs, action=action, reward=reward, done=done ).to_sequence() carry = self.q_network.initialize_carry((self.cfg.num_envs, None)) params = self.q_network.init( {"params": q_key, "torso": torso_key}, *timestep, initial_carry=carry, ) optimizer_state = self.optimizer.init(params) return PQNState( step=0, update_step=0, timestep=timestep.from_sequence(), carry=carry, env_state=env_state, params=params, optimizer_state=optimizer_state, )
[docs] def warmup(self, key: Key, state: PQNState, num_steps: int) -> PQNState: return state
[docs] def train( self, key: Key, state: PQNState, num_steps: int, ) -> PQNState: keys = jax.random.split( key, num_steps // (self.cfg.num_steps * self.cfg.num_envs) ) state, _ = jax.lax.scan( self._update_step, state, keys, ) return state
[docs] def evaluate(self, key: Key, state: PQNState, num_steps: int) -> PQNState: reset_key, eval_key = jax.random.split(key) reset_key = jax.random.split(reset_key, self.cfg.num_envs) obs, env_state = jax.vmap(self.env.reset, in_axes=(0, None))( reset_key, self.env_params ) action = jnp.zeros( (self.cfg.num_envs, *self.env.action_space(self.env_params).shape), dtype=self.env.action_space(self.env_params).dtype, ) reward = jnp.zeros((self.cfg.num_envs,), dtype=jnp.float32) done = jnp.ones(self.cfg.num_envs, dtype=jnp.bool_) timestep = Timestep(obs=obs, action=action, reward=reward, done=done) carry = self.q_network.initialize_carry((self.cfg.num_envs, None)) state = state.replace(timestep=timestep, carry=carry, env_state=env_state) step_keys = jax.random.split(eval_key, num_steps) state, _ = jax.lax.scan( partial(self._step, policy=self._greedy_action), state, step_keys, ) return state