from functools import partial
from typing import Any, Callable, Optional
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
from flax import core, struct
from memorax.utils import Transition, periodic_incremental_update
from memorax.utils.typing import (Array, Buffer, BufferState, Environment,
EnvParams, EnvState, Key)
[docs]
@struct.dataclass(frozen=True)
class DQNConfig:
name: str
learning_rate: float
num_envs: int
num_eval_envs: int
buffer_size: int
gamma: float
tau: float
target_network_frequency: int
batch_size: int
start_e: float
end_e: float
exploration_fraction: float
double: bool
learning_starts: int
train_frequency: int
mask: bool
burn_in_length: int = 0
[docs]
@struct.dataclass(frozen=True)
class DQNState:
step: int
obs: Array
done: Array
hidden_state: tuple
env_state: EnvState
params: core.FrozenDict[str, Any]
target_params: core.FrozenDict[str, Any]
optimizer_state: optax.OptState
buffer_state: BufferState
[docs]
@struct.dataclass(frozen=True)
class DQN:
cfg: DQNConfig
env: Environment
env_params: EnvParams
q_network: nn.Module
optimizer: optax.GradientTransformation
buffer: Buffer
epsilon_schedule: optax.Schedule
def _greedy_action(self, key: Key, state: DQNState) -> tuple[Key, DQNState, Array]:
key, memory_key = jax.random.split(key)
hidden_state, q_values = self.q_network.apply(
state.params,
jnp.expand_dims(state.obs, 1),
mask=jnp.expand_dims(state.done, 1),
initial_carry=state.hidden_state,
rngs={"memory": memory_key},
)
action = jnp.argmax(q_values, axis=-1).squeeze(-1)
state = state.replace(hidden_state=hidden_state)
return key, state, action
def _random_action(self, key: Key, state: DQNState) -> tuple[Key, DQNState, Array]:
key, action_key = jax.random.split(key)
action_key = jax.random.split(action_key, self.cfg.num_envs)
action = jax.vmap(self.env.action_space(self.env_params).sample)(action_key)
return key, state, action
def _epsilon_greedy_action(
self, key: Key, state: DQNState
) -> tuple[Key, DQNState, Array]:
key, state, random_action = self._random_action(key, state)
key, state, greedy_action = self._greedy_action(key, state)
key, sample_key = jax.random.split(key)
epsilon = self.epsilon_schedule(state.step)
action = jnp.where(
jax.random.uniform(sample_key, greedy_action.shape) < epsilon,
random_action,
greedy_action,
)
return key, state, action
def _step(
self, carry, _, *, policy: Callable, write_to_buffer: bool = True
) -> tuple[Key, DQNState]:
key, state = carry
key, action_key, step_key = jax.random.split(key, 3)
key, state, action = policy(action_key, state)
num_envs = state.obs.shape[0]
step_key = jax.random.split(step_key, num_envs)
next_obs, env_state, reward, done, info = jax.vmap(
self.env.step, in_axes=(0, 0, 0, None)
)(step_key, state.env_state, action, self.env_params)
transition = Transition(
obs=state.obs, # type: ignore
action=action, # type: ignore
reward=reward, # type: ignore
next_obs=next_obs, # type: ignore
done=done, # type: ignore
info=info, # type: ignore
prev_done=state.done, # type: ignore
)
buffer_state = state.buffer_state
if write_to_buffer:
transition = jax.tree.map(lambda x: jnp.expand_dims(x, 1), transition)
buffer_state = self.buffer.add(state.buffer_state, transition)
state = state.replace(
step=state.step + self.cfg.num_envs,
obs=next_obs, # type: ignore
done=done, # type: ignore
env_state=env_state, # type: ignore
buffer_state=buffer_state,
)
return (key, state), transition
def _update(self, key: Key, state: DQNState) -> tuple[DQNState, Array, Array]:
batch = self.buffer.sample(state.buffer_state, key)
key, memory_key, next_memory_key = jax.random.split(key, 3)
experience = batch.experience
initial_carry = None
initial_target_carry = None
if self.cfg.burn_in_length > 0:
burn_in = jax.tree.map(
lambda x: x[:, : self.cfg.burn_in_length], experience
)
initial_carry, _ = self.q_network.apply(
jax.lax.stop_gradient(state.params),
burn_in.obs,
mask=burn_in.prev_done,
)
initial_carry = jax.lax.stop_gradient(initial_carry)
initial_target_carry, _ = self.q_network.apply(
jax.lax.stop_gradient(state.target_params),
burn_in.next_obs,
mask=burn_in.done,
)
initial_target_carry = jax.lax.stop_gradient(initial_target_carry)
experience = jax.tree.map(
lambda x: x[:, self.cfg.burn_in_length :], experience
)
def make_dqn_target():
_, next_target_q_values = self.q_network.apply(
state.target_params,
experience.next_obs,
mask=experience.done,
initial_carry=initial_target_carry,
rngs={"memory": next_memory_key},
)
return jnp.max(next_target_q_values, axis=-1)
def make_double_dqn_target():
_, next_q_values = self.q_network.apply(
state.target_params,
experience.next_obs,
mask=experience.done,
initial_carry=initial_target_carry,
rngs={"memory": next_memory_key},
)
next_action = jnp.argmax(next_q_values, axis=-1)
next_target_q_value = jnp.take_along_axis(
next_target_q_values, jnp.expand_dims(next_action, -1), axis=-1
).squeeze(-1)
return next_target_q_value
next_target_q_value = (
make_dqn_target() if not self.cfg.double else make_double_dqn_target()
)
td_target = (
experience.reward
+ (1 - experience.done) * self.cfg.gamma * next_target_q_value
)
mask = jnp.ones_like(td_target)
if self.cfg.mask:
episode_idx = jnp.cumsum(experience.prev_done, axis=1)
terminal = (episode_idx == 1) & experience.prev_done
mask *= (episode_idx == 0) | terminal
def loss_fn(params):
hidden_state, q_value = self.q_network.apply(
params,
experience.obs,
mask=experience.prev_done,
initial_carry=initial_carry,
rngs={"memory": memory_key},
)
action = jnp.expand_dims(experience.action, axis=-1)
q_value = jnp.take_along_axis(q_value, action, axis=-1).squeeze(-1)
td_error = q_value - td_target
loss = (jnp.square(td_error)).mean(where=mask.astype(jnp.bool_))
return loss, (q_value, td_error, hidden_state)
(loss, (q_value, td_error, hidden_state)), grads = jax.value_and_grad(
loss_fn, has_aux=True
)(state.params)
updates, optimizer_state = self.optimizer.update(
grads, state.optimizer_state, state.params
)
params = optax.apply_updates(state.params, updates)
target_params = periodic_incremental_update(
params,
state.target_params,
state.step,
self.cfg.target_network_frequency,
self.cfg.tau,
)
buffer_state = state.buffer_state
state = state.replace(
params=params,
target_params=target_params,
optimizer_state=optimizer_state,
buffer_state=buffer_state,
)
return state, loss, q_value.mean()
def _learn(self, carry, _):
(key, state), transitions = jax.lax.scan(
partial(self._step, policy=self._epsilon_greedy_action),
carry,
length=self.cfg.train_frequency // self.cfg.num_envs,
)
key, update_key = jax.random.split(key)
state, loss, q_value = self._update(update_key, state)
transitions.info["losses/loss"] = loss
transitions.info["losses/q_value"] = q_value
return (key, state), transitions.replace(obs=None, next_obs=None)
[docs]
@partial(jax.jit, static_argnames=["self"])
def init(self, key):
key, env_key, q_key, memory_key = jax.random.split(key, 4)
env_keys = jax.random.split(env_key, self.cfg.num_envs)
obs, env_state = jax.vmap(self.env.reset, in_axes=(0, None))(
env_keys, self.env_params
)
action = jax.vmap(self.env.action_space(self.env_params).sample)(env_keys)
_, _, reward, _, info = jax.vmap(self.env.step, in_axes=(0, 0, 0, None))(
env_keys, env_state, action, self.env_params
)
done = jnp.ones(self.cfg.num_envs, dtype=jnp.bool)
carry = self.q_network.initialize_carry(obs.shape)
params = self.q_network.init(
{"params": q_key, "memory": memory_key},
observation=jnp.expand_dims(obs, 1),
mask=jnp.expand_dims(done, 1),
initial_carry=carry,
)
target_params = self.q_network.init(
{"params": q_key, "memory": memory_key},
observation=jnp.expand_dims(obs, 1),
mask=jnp.expand_dims(done, 1),
initial_carry=carry,
)
optimizer_state = self.optimizer.init(params)
transition = Transition(
obs=obs,
action=action,
reward=reward,
next_obs=obs,
done=done,
info=info,
prev_done=done,
) # type: ignore
buffer_state = self.buffer.init(jax.tree.map(lambda x: x[0], transition))
return (
key,
DQNState(
step=0, # type: ignore
obs=obs, # type: ignore
done=done, # type: ignore
hidden_state=carry, # type: ignore
env_state=env_state, # type: ignore
params=params, # type: ignore
target_params=target_params, # type: ignore
optimizer_state=optimizer_state, # type: ignore
buffer_state=buffer_state, # type: ignore
),
)
[docs]
@partial(jax.jit, static_argnames=["self", "num_steps"])
def warmup(self, key: Key, state: DQNState, num_steps: int) -> tuple[Key, DQNState]:
(key, state), _ = jax.lax.scan(
partial(self._step, policy=self._random_action),
(key, state),
length=num_steps // self.cfg.num_envs,
)
return key, state
[docs]
@partial(jax.jit, static_argnames=["self", "num_steps"])
def train(
self,
key: Key,
state: DQNState,
num_steps: int,
):
(
(
key,
state,
),
transitions,
) = jax.lax.scan(
self._learn,
(key, state),
length=(num_steps // self.cfg.train_frequency),
)
return key, state, transitions
[docs]
@partial(jax.jit, static_argnames=["self", "num_steps"])
def evaluate(self, key: Key, state: DQNState, num_steps: int) -> tuple[Key, dict]:
key, reset_key = jax.random.split(key)
reset_key = jax.random.split(reset_key, self.cfg.num_eval_envs)
obs, env_state = jax.vmap(self.env.reset, in_axes=(0, None))(
reset_key, self.env_params
)
done = jnp.zeros(self.cfg.num_eval_envs, dtype=jnp.bool_)
hidden_state = self.q_network.initialize_carry(obs.shape)
state = state.replace(
obs=obs, done=done, hidden_state=hidden_state, env_state=env_state
) # type: ignore
(key, _), transitions = jax.lax.scan(
partial(self._step, policy=self._greedy_action, write_to_buffer=False),
(key, state),
length=num_steps,
)
return key, transitions.replace(obs=None, next_obs=None)