Source code for memorax.algorithms.pqn

from functools import partial
from typing import Any, Callable

import flax.linen as nn
import gymnax
import jax
import jax.numpy as jnp
import optax
from flax import core, struct

from memorax.networks import SequenceModelWrapper
from memorax.networks.sequence_models.utils import (
    add_feature_axis,
    remove_feature_axis,
    remove_time_axis,
)
from memorax.utils import Timestep, Transition
from memorax.utils.typing import Array, Environment, EnvParams, EnvState, Key


@struct.dataclass(frozen=True)
class PQNConfig:
    """Static hyperparameters consumed by :class:`PQN`.

    Attributes:
        name: Label for this configuration (not read by the code in this module).
        num_envs: Number of environments reset and stepped in parallel for training.
        num_eval_envs: Number of environments reset and stepped by ``PQN.evaluate``.
        num_steps: Number of environment steps collected per learning iteration.
        gamma: Discount factor used when building the lambda-return targets.
        td_lambda: Lambda coefficient used when building the lambda-return targets.
        num_minibatches: Number of minibatches each update epoch is split into.
        update_epochs: Number of update epochs run per learning iteration.
        burn_in_length: Number of leading time steps of every minibatch that are
            only used to advance the hidden state before the loss is computed.
            ``0`` (the default) disables burn-in.
    """

    name: str
    num_envs: int
    num_eval_envs: int
    num_steps: int
    gamma: float
    td_lambda: float
    num_minibatches: int
    update_epochs: int
    burn_in_length: int = 0

    @property
    def batch_size(self) -> int:
        """Number of transitions gathered per learning iteration."""
        transitions_per_rollout = self.num_envs * self.num_steps
        return transitions_per_rollout
@struct.dataclass(frozen=True)
class PQNState:
    """Mutable-by-replacement training state threaded through :class:`PQN`.

    Attributes:
        step: Running step counter; ``PQN._step`` adds ``cfg.num_envs`` to it
            on every call.
        timestep: Most recent ``Timestep`` (obs, action, reward, done) that is
            fed to the Q-network on the next action selection.
        env_state: Environment state returned by the vmapped ``env.reset`` /
            ``env.step`` calls.
        params: Parameters of the Q-network.
        hidden_state: Carry of the Q-network, as produced by
            ``q_network.initialize_carry`` and updated by greedy action selection.
        optimizer_state: State of the optax optimizer.
    """

    step: int
    timestep: Timestep
    env_state: EnvState
    params: core.FrozenDict[str, Any]
    hidden_state: Array
    optimizer_state: optax.OptState
@struct.dataclass(frozen=True)
class PQN:
    """Q-learning agent with lambda-return targets and a sequence Q-network.

    Rollouts are collected with an epsilon-greedy policy over ``cfg.num_envs``
    vmapped environments, lambda-return targets are computed with a reverse
    scan, and the Q-network is trained for ``cfg.update_epochs`` epochs of
    ``cfg.num_minibatches`` minibatches on a squared-error loss.

    Attributes:
        cfg: Hyperparameters.
        env: Environment exposing ``reset``, ``step`` and ``action_space``.
        env_params: Parameters passed to every environment call.
        q_network: Flax module returning ``(carry, q_values)`` from ``apply``.
        optimizer: Optax gradient transformation applied to the parameters.
        epsilon_schedule: Maps ``state.step`` to the exploration probability.
    """

    cfg: PQNConfig
    env: Environment
    env_params: EnvParams
    q_network: nn.Module
    optimizer: optax.GradientTransformation
    epsilon_schedule: optax.Schedule

    def _greedy_action(
        self, key: Key, state: PQNState
    ) -> tuple[Key, PQNState, Array, Array]:
        """Pick the arg-max action and advance the network's hidden state.

        Returns:
            ``(key, state, action, q_values)`` where ``key`` is passed through
            unchanged and ``state`` carries the updated hidden state.
        """
        timestep = state.timestep.to_sequence()
        hidden_state, q_values = self.q_network.apply(
            state.params,
            observation=timestep.obs,
            mask=timestep.done,
            action=timestep.action,
            reward=add_feature_axis(timestep.reward),
            done=timestep.done,
            initial_carry=state.hidden_state,
        )
        q_values = remove_time_axis(q_values)
        action = jnp.argmax(q_values, axis=-1)
        state = state.replace(hidden_state=hidden_state)
        return key, state, action, q_values

    def _random_action(
        self, key: Key, state: PQNState
    ) -> tuple[Key, PQNState, Array, None]:
        """Sample one action per training environment from the action space.

        The state is returned untouched and no Q-values are produced.
        """
        key, action_key = jax.random.split(key)
        action_key = jax.random.split(action_key, self.cfg.num_envs)
        action = jax.vmap(self.env.action_space(self.env_params).sample)(action_key)
        return key, state, action, None

    def _epsilon_greedy_action(
        self, key: Key, state: PQNState
    ) -> tuple[Key, PQNState, Array, Array]:
        """Mix random and greedy actions according to ``epsilon_schedule``.

        Both candidate actions are always computed; a per-element uniform draw
        then selects the random one with probability ``epsilon``.
        """
        key, state, random_action, _ = self._random_action(key, state)
        key, state, greedy_action, q_values = self._greedy_action(key, state)

        key, sample_key = jax.random.split(key)
        epsilon = self.epsilon_schedule(state.step)
        action = jnp.where(
            jax.random.uniform(sample_key, greedy_action.shape) < epsilon,
            random_action,
            greedy_action,
        )
        return key, state, action, q_values

    def _step(
        self, carry, _, *, policy: Callable
    ) -> tuple[tuple[Key, PQNState], Transition]:
        """Scan body: act with ``policy``, step every environment once.

        Args:
            carry: ``(key, state)`` pair threaded through the scan.
            _: Unused scan input.
            policy: Callable ``(key, state) -> (key, state, action, q_values)``.

        Returns:
            The new ``(key, state)`` carry and the recorded ``Transition``.
        """
        key, state = carry

        key, action_key, step_key = jax.random.split(key, 3)
        # The key returned by the policy (derived from `action_key`) becomes
        # the carry key for the next step.
        key, state, action, q_values = policy(action_key, state)

        # Read the batch size from the observation so the same body serves
        # both training and evaluation batches.
        num_envs, *_ = state.timestep.obs.shape
        step_key = jax.random.split(step_key, num_envs)
        next_obs, env_state, reward, done, info = jax.vmap(
            self.env.step, in_axes=(0, 0, 0, None)
        )(step_key, state.env_state, action, self.env_params)

        # Zero the previous action/reward wherever the previous step ended an
        # episode, so they do not leak into the next one.
        prev_action = jnp.where(
            state.timestep.done,
            jnp.zeros_like(state.timestep.action),
            state.timestep.action,
        )
        prev_reward = jnp.where(
            state.timestep.done,
            jnp.zeros_like(state.timestep.reward),
            state.timestep.reward,
        )

        transition = Transition(
            obs=state.timestep.obs,  # type: ignore
            action=action,  # type: ignore
            reward=reward,  # type: ignore
            done=done,  # type: ignore
            info=info,  # type: ignore
            value=q_values,  # type: ignore
            next_obs=next_obs,  # type: ignore
            prev_action=prev_action,  # type: ignore
            prev_reward=prev_reward,  # type: ignore
            prev_done=state.timestep.done,  # type: ignore
        )

        state = state.replace(
            step=state.step + self.cfg.num_envs,
            timestep=Timestep(obs=next_obs, action=action, reward=reward, done=done),  # type: ignore
            env_state=env_state,  # type: ignore
        )
        return (key, state), transition

    def _td_lambda(self, carry, transition):
        """Reverse-scan body computing one step of the lambda-return.

        Args:
            carry: ``(lambda_return, next_q_value)`` from the following step
                (the scan in ``_learn`` runs with ``reverse=True``).
            transition: Transition of the current step.

        Returns:
            The new carry and the lambda-return target for this step.
        """
        lambda_return, next_q_value = carry

        target_bootstrap = (
            transition.reward + self.cfg.gamma * (1.0 - transition.done) * next_q_value
        )
        delta = lambda_return - next_q_value
        lambda_return = target_bootstrap + self.cfg.gamma * self.cfg.td_lambda * delta

        # On terminal steps the target is the immediate reward only.
        lambda_return = (
            1.0 - transition.done
        ) * lambda_return + transition.done * transition.reward

        q_value = jnp.max(transition.value, axis=-1)
        return (lambda_return, q_value), lambda_return

    def _update_epoch(self, carry, _):
        """Scan body: shuffle the rollout into minibatches and train on each.

        Args:
            carry: ``(key, state, initial_hidden_state, transitions,
                lambda_targets)``; everything except ``key`` and ``state`` is
                passed through unchanged.
            _: Unused scan input.

        Returns:
            The updated carry and ``(loss, q_value)`` stacked over minibatches.
        """
        key, state, initial_hidden_state, transitions, lambda_targets = carry

        key, permutation_key = jax.random.split(key)
        batch = (initial_hidden_state, transitions, lambda_targets)

        def shuffle(batch):
            # When the network is a SequenceModelWrapper, time steps are also
            # shuffled: every step becomes its own length-1 sequence.
            shuffle_time_axis = isinstance(self.q_network, SequenceModelWrapper)
            num_permutations = self.cfg.num_envs
            if shuffle_time_axis:
                batch = (
                    initial_hidden_state,
                    *jax.tree.map(
                        lambda x: x.reshape(-1, 1, *x.shape[2:]),
                        (transitions, lambda_targets),
                    ),
                )
                num_permutations *= self.cfg.num_steps
                # NOTE(review): in this branch the permutation indexes
                # num_envs * num_steps rows, but `initial_hidden_state` is not
                # reshaped here — confirm the wrapper's carry tolerates this.

            permutation = jax.random.permutation(permutation_key, num_permutations)
            minibatches = jax.tree.map(
                lambda x: jnp.take(x, permutation, axis=0).reshape(
                    self.cfg.num_minibatches, -1, *x.shape[1:]
                ),
                tuple(batch),
            )
            return minibatches

        minibatches = shuffle(batch)

        (key, state), (loss, q_value) = jax.lax.scan(
            self._update_minibatch, (key, state), xs=minibatches
        )

        return (key, state, initial_hidden_state, transitions, lambda_targets), (
            loss,
            q_value,
        )

    def _update_minibatch(
        self, carry, minibatch
    ) -> tuple[tuple[Key, PQNState], tuple[Array, Array]]:
        """Scan body: one gradient step on a single minibatch.

        Args:
            carry: ``(key, state)`` pair threaded through the scan.
            minibatch: ``(hidden_state, transitions, target)`` slice.

        Returns:
            The new ``(key, state)`` carry and ``(loss, mean_q_value)``.
        """
        key, state = carry
        hidden_state, transitions, target = minibatch

        key, memory_key, dropout_key = jax.random.split(key, 3)

        if self.cfg.burn_in_length > 0:
            # Advance the hidden state over the burn-in prefix without
            # gradients, then train only on the remaining steps.
            burn_in = jax.tree.map(
                lambda x: x[:, : self.cfg.burn_in_length], transitions
            )
            hidden_state, _ = self.q_network.apply(
                jax.lax.stop_gradient(state.params),
                observation=burn_in.obs,
                mask=burn_in.prev_done,
                action=burn_in.prev_action,
                reward=add_feature_axis(burn_in.prev_reward),
                done=burn_in.prev_done,
                initial_carry=hidden_state,
            )
            hidden_state = jax.lax.stop_gradient(hidden_state)
            transitions = jax.tree.map(
                lambda x: x[:, self.cfg.burn_in_length :], transitions
            )
            target = target[:, self.cfg.burn_in_length :]

        def loss_fn(params):
            _, q_values = self.q_network.apply(
                params,
                observation=transitions.obs,
                mask=transitions.prev_done,
                action=transitions.prev_action,
                reward=add_feature_axis(transitions.prev_reward),
                done=transitions.prev_done,
                initial_carry=hidden_state,
                rngs={"memory": memory_key, "dropout": dropout_key},
            )
            # Select the Q-value of the action that was actually taken.
            action = add_feature_axis(transitions.action)
            q_value = jnp.take_along_axis(q_values, action, axis=-1)
            q_value = remove_feature_axis(q_value)
            loss = 0.5 * jnp.square(q_value - target).mean()
            return loss, q_value.mean()

        (loss, q_value), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
        updates, optimizer_state = self.optimizer.update(
            grads, state.optimizer_state, state.params
        )
        params = optax.apply_updates(state.params, updates)

        state = state.replace(
            params=params,
            optimizer_state=optimizer_state,
        )

        return (key, state), (loss, q_value)

    def _learn(
        self, carry: tuple[Key, PQNState], _
    ) -> tuple[tuple[Key, PQNState], Transition]:
        """Scan body: collect one rollout and run all update epochs on it.

        Args:
            carry: ``(key, state)`` pair threaded through the scan.
            _: Unused scan input.

        Returns:
            The new carry and the rollout transitions (env axis first, time
            axis second) with ``obs``/``next_obs`` dropped and the losses and
            mean Q-values stored under ``info["losses/..."]``.
        """
        key, state = carry

        # Hidden state at the start of the rollout; used to initialise the
        # network when replaying the rollout during the updates.
        initial_hidden_state = state.hidden_state
        (key, state), transitions = jax.lax.scan(
            partial(self._step, policy=self._epsilon_greedy_action),
            (key, state),
            length=self.cfg.num_steps,
        )

        # Bootstrap value for the observation that follows the rollout.
        key, memory_key, dropout_key = jax.random.split(key, 3)
        timestep = state.timestep.to_sequence()
        _, q_values = self.q_network.apply(
            state.params,
            observation=timestep.obs,
            mask=timestep.done,
            action=timestep.action,
            reward=add_feature_axis(timestep.reward),
            done=timestep.done,
            initial_carry=state.hidden_state,
            rngs={"memory": memory_key, "dropout": dropout_key},
        )
        q_value = jnp.max(q_values, axis=-1) * (1.0 - timestep.done)
        q_value = remove_time_axis(q_value)

        _, targets = jax.lax.scan(
            self._td_lambda,
            (q_value, q_value),
            transitions,
            reverse=True,
        )

        # The rollout scan stacks along the time axis first; swap so the env
        # axis leads for minibatching.
        transitions = jax.tree.map(lambda x: jnp.swapaxes(x, 0, 1), transitions)
        targets = jnp.swapaxes(targets, 0, 1)

        (key, state, _, transitions, _), (loss, q_value) = jax.lax.scan(
            self._update_epoch,
            (key, state, initial_hidden_state, transitions, targets),
            None,
            self.cfg.update_epochs,
        )

        transitions.info["losses/loss"] = loss
        transitions.info["losses/q_value"] = q_value

        return (key, state), transitions.replace(obs=None, next_obs=None)

    @partial(jax.jit, static_argnames=["self"])
    def init(self, key) -> tuple[Key, PQNState]:
        """Reset the training environments and build the initial state.

        Args:
            key: PRNG key.

        Returns:
            ``(key, state)`` with a fresh key and the initial ``PQNState``
            (step 0, zero previous action/reward, ``done`` all false).
        """
        key, env_key, q_key, memory_key = jax.random.split(key, 4)

        env_keys = jax.random.split(env_key, self.cfg.num_envs)
        obs, env_state = jax.vmap(self.env.reset, in_axes=(0, None))(
            env_keys, self.env_params
        )
        action = jnp.zeros(
            (self.cfg.num_envs, *self.env.action_space(self.env_params).shape),
            dtype=self.env.action_space(self.env_params).dtype,
        )
        reward = jnp.zeros((self.cfg.num_envs,), dtype=jnp.float32)
        done = jnp.zeros(self.cfg.num_envs, dtype=jnp.bool)
        timestep = Timestep(
            obs=obs, action=action, reward=reward, done=done
        ).to_sequence()

        hidden_state = self.q_network.initialize_carry((self.cfg.num_envs, None))
        params = self.q_network.init(
            {"params": q_key, "memory": memory_key},
            observation=timestep.obs,
            mask=timestep.done,
            action=timestep.action,
            reward=add_feature_axis(timestep.reward),
            done=timestep.done,
            initial_carry=hidden_state,
        )
        optimizer_state = self.optimizer.init(params)

        return (
            key,
            PQNState(
                step=0,  # type: ignore
                timestep=timestep.from_sequence(),  # type: ignore
                hidden_state=hidden_state,  # type: ignore
                env_state=env_state,  # type: ignore
                params=params,  # type: ignore
                optimizer_state=optimizer_state,  # type: ignore
            ),
        )

    @partial(jax.jit, static_argnames=["self", "num_steps"])
    def warmup(self, key: Key, state: PQNState, num_steps: int) -> tuple[Key, PQNState]:
        """No-op: return ``key`` and ``state`` unchanged (``num_steps`` is unused)."""
        return key, state

    @partial(jax.jit, static_argnames=["self", "num_steps"])
    def train(
        self,
        key: Key,
        state: PQNState,
        num_steps: int,
    ) -> tuple[Key, PQNState, Transition]:
        """Run as many learning iterations as fit into ``num_steps``.

        Args:
            key: PRNG key.
            state: Current training state.
            num_steps: Environment-step budget; the number of iterations is
                ``num_steps // (cfg.num_steps * cfg.num_envs)``.

        Returns:
            ``(key, state, transitions)``. Each leaf of ``transitions`` has its
            iteration axis merged with its (swapped-in) second axis into a
            single leading axis.
        """
        (key, state), transitions = jax.lax.scan(
            self._learn,
            (key, state),
            length=(num_steps // (self.cfg.num_steps * self.cfg.num_envs)),
        )

        def merge_leading_axes(x):
            # `_learn` returns leaves with the env axis before the time axis;
            # swap them back, then fold the iteration axis into the new
            # leading axis. The target shape must come from the *swapped*
            # array: using the pre-swap shape would misassign elements
            # whenever the two swapped axes differ in length.
            swapped = x.swapaxes(1, 2)
            return swapped.reshape((-1,) + swapped.shape[2:])

        transitions = jax.tree.map(merge_leading_axes, transitions)

        return key, state, transitions

    @partial(jax.jit, static_argnames=["self", "num_steps"])
    def evaluate(
        self, key: Key, state: PQNState, num_steps: int
    ) -> tuple[Key, Transition]:
        """Roll out the greedy policy in freshly reset evaluation environments.

        Args:
            key: PRNG key.
            state: Training state supplying the network parameters; its
                timestep, hidden state and env state are replaced locally.
            num_steps: Number of steps taken in each evaluation environment.

        Returns:
            ``(key, transitions)`` with ``obs``/``next_obs`` dropped. The
            evaluation state is discarded.
        """
        key, reset_key = jax.random.split(key)
        reset_key = jax.random.split(reset_key, self.cfg.num_eval_envs)
        obs, env_state = jax.vmap(self.env.reset, in_axes=(0, None))(
            reset_key, self.env_params
        )
        action = jnp.zeros(
            (self.cfg.num_eval_envs, *self.env.action_space(self.env_params).shape),
            dtype=self.env.action_space(self.env_params).dtype,
        )
        reward = jnp.zeros((self.cfg.num_eval_envs,), dtype=jnp.float32)
        done = jnp.zeros(self.cfg.num_eval_envs, dtype=jnp.bool)
        timestep = Timestep(obs=obs, action=action, reward=reward, done=done)
        hidden_state = self.q_network.initialize_carry((self.cfg.num_eval_envs, None))

        state = state.replace(
            timestep=timestep, hidden_state=hidden_state, env_state=env_state
        )

        (key, *_), transitions = jax.lax.scan(
            partial(self._step, policy=self._greedy_action),
            (key, state),
            length=num_steps,
        )

        return key, transitions.replace(obs=None, next_obs=None)