Source code for memorax.algorithms.dqn

from dataclasses import dataclass
from functools import partial
from typing import Any, Callable

import flax.linen as nn
import jax
import jax.numpy as jnp
import lox
import optax
from flax import core, struct

from memorax.utils import Timestep, Transition, periodic_incremental_update, utils
from memorax.utils.axes import (
    add_feature_axis,
    add_time_axis,
    remove_feature_axis,
    remove_time_axis,
)
from memorax.utils.typing import (
    Array,
    Buffer,
    BufferState,
    Carry,
    Environment,
    EnvParams,
    EnvState,
    Key,
    PyTree,
)


[docs] @struct.dataclass(frozen=True) class DQNConfig: num_envs: int gamma: float tau: float target_update_frequency: int train_frequency: int burn_in_length: int = 0
[docs] @struct.dataclass(frozen=True) class DQNState: step: int update_step: int timestep: Timestep carry: tuple env_state: EnvState params: core.FrozenDict[str, Any] target_params: core.FrozenDict[str, Any] optimizer_state: optax.OptState buffer_state: BufferState
[docs] @dataclass class DQN: cfg: DQNConfig env: Environment env_params: EnvParams q_network: nn.Module optimizer: optax.GradientTransformation buffer: Buffer epsilon_schedule: optax.Schedule def __post_init__(self): assert ( self.cfg.train_frequency >= self.cfg.num_envs ), f"train_frequency ({self.cfg.train_frequency}) must be >= num_envs ({self.cfg.num_envs})" assert ( self.cfg.train_frequency % self.cfg.num_envs == 0 ), f"train_frequency ({self.cfg.train_frequency}) must be divisible by num_envs ({self.cfg.num_envs})" def _greedy_action(self, key: Key, state: DQNState) -> tuple[DQNState, Array, dict]: torso_key = key (carry, (q_values, _)), intermediates = self.q_network.apply( state.params, *state.timestep.to_sequence(), initial_carry=state.carry, rngs={"torso": torso_key}, mutable=["intermediates"], ) action = jnp.argmax(q_values, axis=-1) action = remove_time_axis(action) state = state.replace(carry=carry) return state, action, intermediates def _random_action(self, key: Key, state: DQNState) -> tuple[DQNState, Array, dict]: action_key = jax.random.split(key, self.cfg.num_envs) action = jax.vmap(self.env.action_space(self.env_params).sample)(action_key) return state, action, {} def _epsilon_greedy_action( self, key: Key, state: DQNState ) -> tuple[DQNState, Array, dict]: random_key, greedy_key, sample_key = jax.random.split(key, 3) state, random_action, _ = self._random_action(random_key, state) state, greedy_action, intermediates = self._greedy_action(greedy_key, state) epsilon = self.epsilon_schedule(state.step) action = jnp.where( jax.random.uniform(sample_key, greedy_action.shape) < epsilon, random_action, greedy_action, ) return state, action, intermediates def _step( self, state: DQNState, key: Key, *, policy: Callable ) -> tuple[DQNState, Transition]: action_key, step_key = jax.random.split(key) initial_carry = state.carry state, action, intermediates = policy(action_key, state) num_envs, *_ = state.timestep.obs.shape step_key = jax.random.split(step_key, num_envs) next_obs, env_state, reward, done, info = jax.vmap( self.env.step, in_axes=(0, 0, 0, None) )(step_key, state.env_state, action, self.env_params) intermediates = jax.tree.map( lambda x: jnp.mean(jnp.stack(x)), intermediates.get("intermediates", {}), is_leaf=lambda x: isinstance(x, tuple), ) lox.log({"intermediates": intermediates, "info": info}) first = Timestep( obs=state.timestep.obs, action=state.timestep.action, reward=state.timestep.reward, done=state.timestep.done, ).to_sequence() second = Timestep( obs=next_obs, action=action, reward=reward, done=done, ).to_sequence() transition = Transition( first=first, second=second, carry=jax.tree.map(add_time_axis, initial_carry), ) buffer_state = self.buffer.add(state.buffer_state, transition) next_reward = jnp.asarray(reward, dtype=jnp.float32) state = state.replace( step=state.step + self.cfg.num_envs, timestep=Timestep( obs=next_obs, action=jnp.where(done, jnp.zeros_like(action), action), reward=jnp.where(done, jnp.zeros_like(next_reward), next_reward), done=done, ), env_state=env_state, buffer_state=buffer_state, ) return state, transition def _update(self, key: Key, state: DQNState) -> DQNState: batch_key, torso_key, next_torso_key = jax.random.split(key, 3) batch = self.buffer.sample(state.buffer_state, batch_key) experience = batch.experience experience = jax.tree.map(lambda x: jnp.expand_dims(x, 1), experience) initial_carry = None initial_target_carry = None if experience.carry is not None: initial_carry = jax.tree.map(lambda x: x[:, 0], experience.carry) initial_target_carry = jax.tree.map(lambda x: x[:, 0], experience.carry) initial_carry = utils.burn_in( self.q_network, state.params, experience.first, initial_carry, self.cfg.burn_in_length, ) initial_target_carry = utils.burn_in( self.q_network, state.target_params, experience.second, initial_target_carry, self.cfg.burn_in_length, ) experience = jax.tree.map(lambda x: x[:, self.cfg.burn_in_length :], experience) _, (next_target_q_values, _) = self.q_network.apply( state.target_params, *experience.second, initial_carry=initial_target_carry, rngs={"torso": next_torso_key}, ) next_target_q_value = jnp.max(next_target_q_values, axis=-1) td_target = ( experience.second.reward + self.cfg.gamma * (1 - experience.second.done) * next_target_q_value ) def loss_fn(params: PyTree): carry, (q_values, aux) = self.q_network.apply( params, *experience.first, initial_carry=initial_carry, rngs={"torso": torso_key}, ) action = add_feature_axis(experience.second.action) q_value = jnp.take_along_axis(q_values, action, axis=-1) q_value = remove_feature_axis(q_value) td_error = q_value - td_target loss = self.q_network.head.loss( q_value, aux, td_target, transitions=experience ).mean() return loss, (q_value, td_error, carry) (loss, (q_value, td_error, carry)), grads = jax.value_and_grad( loss_fn, has_aux=True )(state.params) lox.log({"q_network/gradient_norm": optax.global_norm(grads)}) updates, optimizer_state = self.optimizer.update( grads, state.optimizer_state, state.params ) params = optax.apply_updates(state.params, updates) target_params = periodic_incremental_update( params, state.target_params, state.step, self.cfg.target_update_frequency, self.cfg.tau, ) lox.log( { "q_network/loss": loss, "q_network/q_value": q_value.mean(), "q_network/td_error": td_error.mean(), "training/epsilon": self.epsilon_schedule(state.step), } ) state = state.replace( params=params, target_params=target_params, optimizer_state=optimizer_state, ) return state def _update_step(self, state: DQNState, key: Key) -> tuple[DQNState, None]: step_key, update_key = jax.random.split(key) step_keys = jax.random.split( step_key, self.cfg.train_frequency // self.cfg.num_envs ) state, _ = jax.lax.scan( partial(self._step, policy=self._epsilon_greedy_action), state, step_keys, ) state = self._update(update_key, state) return state.replace(update_step=state.update_step + 1), None
[docs] def init(self, key: Key) -> DQNState: env_key, q_key, torso_key = jax.random.split(key, 3) env_keys = jax.random.split(env_key, self.cfg.num_envs) obs, env_state = jax.vmap(self.env.reset, in_axes=(0, None))( env_keys, self.env_params ) action_space = self.env.action_space(self.env_params) action = jnp.zeros( (self.cfg.num_envs, *action_space.shape), dtype=action_space.dtype ) reward = jnp.zeros((self.cfg.num_envs,), dtype=jnp.float32) done = jnp.ones((self.cfg.num_envs,), dtype=jnp.bool_) carry = self.q_network.initialize_carry((self.cfg.num_envs, None)) timestep = Timestep( obs=obs, action=action, reward=reward, done=done ).to_sequence() params = target_params = self.q_network.init( {"params": q_key, "torso": torso_key}, *timestep, initial_carry=carry, ) optimizer_state = self.optimizer.init(params) timestep = timestep.from_sequence() transition = Transition( first=timestep, second=timestep, carry=carry, ) buffer_state = self.buffer.init(jax.tree.map(lambda x: x[0], transition)) return DQNState( step=0, update_step=0, timestep=timestep, carry=carry, env_state=env_state, params=params, target_params=target_params, optimizer_state=optimizer_state, buffer_state=buffer_state, )
[docs] def warmup(self, key: Key, state: DQNState, num_steps: int) -> DQNState: step_keys = jax.random.split(key, num_steps // self.cfg.num_envs) state, _ = jax.lax.scan( partial(self._step, policy=self._random_action), state, step_keys, ) return state
[docs] def train( self, key: Key, state: DQNState, num_steps: int, ) -> DQNState: keys = jax.random.split(key, num_steps // self.cfg.train_frequency) state, _ = jax.lax.scan( self._update_step, state, keys, ) return state
[docs] def evaluate(self, key: Key, state: DQNState, num_steps: int) -> DQNState: reset_key, eval_key = jax.random.split(key) reset_key = jax.random.split(reset_key, self.cfg.num_envs) obs, env_state = jax.vmap(self.env.reset, in_axes=(0, None))( reset_key, self.env_params ) action_space = self.env.action_space(self.env_params) action = jnp.zeros( (self.cfg.num_envs, *action_space.shape), dtype=action_space.dtype ) reward = jnp.zeros((self.cfg.num_envs,), dtype=jnp.float32) done = jnp.ones((self.cfg.num_envs,), dtype=jnp.bool_) timestep = Timestep(obs=obs, action=action, reward=reward, done=done) carry = self.q_network.initialize_carry((self.cfg.num_envs, None)) state = state.replace(timestep=timestep, carry=carry, env_state=env_state) step_keys = jax.random.split(eval_key, num_steps) state, _ = jax.lax.scan( partial(self._step, policy=self._greedy_action), state, step_keys, ) return state