from dataclasses import dataclass
from functools import partial
from typing import Any, Callable
import flax.linen as nn
import jax
import jax.numpy as jnp
import lox
import optax
from flashbax.utils import get_tree_shape_prefix
from flax import core, struct
from memorax.buffers import compute_importance_weights
from memorax.utils import Timestep, Transition, periodic_incremental_update, utils
from memorax.utils.axes import add_feature_axis, remove_feature_axis, remove_time_axis
from memorax.utils.typing import (
Array,
Buffer,
BufferState,
Carry,
Environment,
EnvParams,
EnvState,
Key,
PyTree,
)
[docs]
@struct.dataclass(frozen=True)
class R2D2Config:
num_envs: int
gamma: float
tau: float
target_update_frequency: int
train_frequency: int
burn_in_length: int = 10
sequence_length: int = 80
n_step: int = 5
priority_exponent: float = 0.9
importance_sampling_exponent: float = 0.6
[docs]
@struct.dataclass(frozen=True)
class R2D2State:
step: int
update_step: int
timestep: Timestep
carry: tuple
env_state: EnvState
params: core.FrozenDict[str, Any]
target_params: core.FrozenDict[str, Any]
optimizer_state: optax.OptState
buffer_state: BufferState
def compute_n_step_returns(
rewards: Array,
dones: Array,
next_q_values: Array,
n_step: int,
gamma: float,
) -> Array:
batch_size, sequence_length = rewards.shape
num_targets = sequence_length - n_step + 1
def compute_target(start_idx: int):
n_step_return = jnp.zeros(batch_size)
discount = 1.0
done = jnp.ones(batch_size)
for i in range(n_step):
idx = start_idx + i
n_step_return = n_step_return + discount * rewards[:, idx] * done
discount = discount * gamma
done = done * (1.0 - dones[:, idx])
bootstrap_idx = start_idx + n_step - 1
n_step_return = (
n_step_return + discount * next_q_values[:, bootstrap_idx] * done
)
return n_step_return
targets = jax.vmap(compute_target)(jnp.arange(num_targets))
targets = targets.T
return targets
[docs]
@dataclass
class R2D2:
cfg: R2D2Config
env: Environment
env_params: EnvParams
q_network: nn.Module
optimizer: optax.GradientTransformation
buffer: Buffer
epsilon_schedule: optax.Schedule
beta_schedule: optax.Schedule
def __post_init__(self):
assert self.cfg.train_frequency >= self.cfg.num_envs, (
f"train_frequency ({self.cfg.train_frequency}) must be >= num_envs ({self.cfg.num_envs})"
)
assert self.cfg.train_frequency % self.cfg.num_envs == 0, (
f"train_frequency ({self.cfg.train_frequency}) must be divisible by num_envs ({self.cfg.num_envs})"
)
assert self.cfg.sequence_length > self.cfg.burn_in_length, (
f"sequence_length ({self.cfg.sequence_length}) must be > burn_in_length ({self.cfg.burn_in_length})"
)
def _greedy_action(
self, key: Key, state: R2D2State
) -> tuple[R2D2State, Array, Array, dict]:
torso_key = key
obs, done, action, reward = state.timestep.to_sequence()
(carry, (q_values, _)), intermediates = self.q_network.apply(
state.params,
observation=obs,
done=done,
action=action,
reward=reward,
initial_carry=state.carry,
rngs={"torso": torso_key},
mutable=["intermediates"],
)
q_values = remove_time_axis(q_values)
action = jnp.argmax(q_values, axis=-1)
state = state.replace(carry=carry)
return state, action, q_values, intermediates
def _random_action(
self, key: Key, state: R2D2State
) -> tuple[R2D2State, Array, None, dict]:
action_key = jax.random.split(key, self.cfg.num_envs)
action = jax.vmap(self.env.action_space(self.env_params).sample)(action_key)
return state, action, None, {}
def _epsilon_greedy_action(
self, key: Key, state: R2D2State
) -> tuple[R2D2State, Array, Array, dict]:
random_key, greedy_key, sample_key = jax.random.split(key, 3)
state, random_action, _, _ = self._random_action(random_key, state)
state, greedy_action, q_values, intermediates = self._greedy_action(greedy_key, state)
epsilon = self.epsilon_schedule(state.step)
action = jnp.where(
jax.random.uniform(sample_key, greedy_action.shape) < epsilon,
random_action,
greedy_action,
)
return state, action, q_values, intermediates
def _step(self, state: R2D2State, key: Key, *, policy: Callable) -> tuple[R2D2State, Transition]:
action_key, step_key = jax.random.split(key)
initial_carry = state.carry
state, action, intermediates = policy(action_key, state)
num_envs, *_ = state.timestep.obs.shape
step_key = jax.random.split(step_key, num_envs)
next_obs, env_state, reward, done, info = jax.vmap(
self.env.step, in_axes=(0, 0, 0, None)
)(step_key, state.env_state, action, self.env_params)
intermediates = jax.tree.map(
lambda x: jnp.mean(jnp.stack(x)),
intermediates.get("intermediates", {}),
is_leaf=lambda x: isinstance(x, tuple),
)
first = Timestep(
obs=state.timestep.obs,
action=state.timestep.action,
reward=state.timestep.reward,
done=state.timestep.done,
)
second = Timestep(
obs=next_obs,
action=action,
reward=reward,
done=done,
)
lox.log({"info": info, "intermediates": intermediates})
transition = Transition(
first=first,
second=second,
carry=initial_carry,
)
buffer_transition = jax.tree.map(lambda x: jnp.expand_dims(x, 1), transition)
buffer_state = self.buffer.add(state.buffer_state, buffer_transition)
next_reward = jnp.asarray(reward, dtype=jnp.float32)
state = state.replace(
step=state.step + self.cfg.num_envs,
timestep=Timestep(
obs=next_obs,
action=jnp.where(done, jnp.zeros_like(action), action),
reward=jnp.where(done, jnp.zeros_like(next_reward), next_reward),
done=done,
),
env_state=env_state,
buffer_state=buffer_state,
)
return state, transition
def _update(self, key: Key, state: R2D2State):
sample_key, torso_key, next_torso_key = jax.random.split(key, 3)
batch = self.buffer.sample(state.buffer_state, sample_key)
experience = batch.experience
initial_carry = None
initial_target_carry = None
if experience.carry is not None:
initial_carry = jax.tree.map(lambda x: x[:, 0], experience.carry)
initial_target_carry = jax.tree.map(lambda x: x[:, 0], experience.carry)
initial_carry = utils.burn_in(self.q_network, state.params, experience.first, initial_carry, self.cfg.burn_in_length)
initial_target_carry = utils.burn_in(self.q_network, state.target_params, experience.second, initial_target_carry, self.cfg.burn_in_length)
experience = jax.tree.map(lambda x: x[:, self.cfg.burn_in_length:], experience)
next_obs, next_done, next_action, next_reward = experience.second
_, (next_target_q_values, _) = self.q_network.apply(
state.target_params,
observation=next_obs,
done=next_done,
action=next_action,
reward=next_reward,
initial_carry=initial_target_carry,
rngs={"torso": next_torso_key},
)
next_target_q_value = jnp.max(next_target_q_values, axis=-1)
_, sequence_length = experience.second.reward.shape
if self.cfg.n_step > 1 and sequence_length >= self.cfg.n_step:
n_step_targets = compute_n_step_returns(
experience.second.reward,
experience.second.done,
next_target_q_value,
self.cfg.n_step,
self.cfg.gamma,
)
_, num_targets = n_step_targets.shape
experience = jax.tree.map(lambda x: x[:, :num_targets], experience)
td_target = n_step_targets
else:
td_target = (
experience.second.reward
+ self.cfg.gamma * (1 - experience.second.done) * next_target_q_value
)
beta = self.beta_schedule(state.step)
add_batch_size, max_length_time_axis = get_tree_shape_prefix(
state.buffer_state.experience, n_axes=2
)
buffer_capacity = add_batch_size * max_length_time_axis
buffer_size = jnp.where(
state.buffer_state.is_full,
buffer_capacity,
state.buffer_state.current_index * add_batch_size,
)
buffer_size = jnp.maximum(buffer_size, 1)
importance_weights = compute_importance_weights(
batch.probabilities, buffer_size, beta
)
importance_weights = importance_weights[:, None]
first_obs, first_done, first_action, first_reward = experience.first
def loss_fn(params: PyTree):
carry, (q_values, aux) = self.q_network.apply(
params,
observation=first_obs,
done=first_done,
action=first_action,
reward=first_reward,
initial_carry=initial_carry,
rngs={"torso": torso_key},
)
action = add_feature_axis(experience.second.action)
q_value = jnp.take_along_axis(q_values, action, axis=-1)
q_value = remove_feature_axis(q_value)
td_error = q_value - td_target
loss = (
importance_weights
* self.q_network.head.loss(
q_value, aux, td_target, transitions=experience
)
).mean()
return loss, (q_value, td_error, carry)
(loss, (q_value, td_error, carry)), grads = jax.value_and_grad(
loss_fn, has_aux=True
)(state.params)
lox.log({"q_network/gradient_norm": optax.global_norm(grads)})
updates, optimizer_state = self.optimizer.update(
grads, state.optimizer_state, state.params
)
params = optax.apply_updates(state.params, updates)
target_params = periodic_incremental_update(
params,
state.target_params,
state.step,
self.cfg.target_update_frequency,
self.cfg.tau,
)
mean_td_error = jnp.abs(td_error).mean(axis=1)
new_priorities = mean_td_error + 1e-6
buffer_state = self.buffer.set_priorities(
state.buffer_state, batch.indices, new_priorities
)
info = {
"q_network/loss": loss,
"q_network/q_value": q_value.mean(),
"q_network/td_error": mean_td_error.mean(),
"training/epsilon": self.epsilon_schedule(state.step),
}
state = state.replace(
params=params,
target_params=target_params,
optimizer_state=optimizer_state,
buffer_state=buffer_state,
)
return state, info
def _update_step(self, state: R2D2State, key: Key) -> tuple[R2D2State, None]:
step_key, update_key = jax.random.split(key)
step_keys = jax.random.split(step_key, self.cfg.train_frequency // self.cfg.num_envs)
state, _ = jax.lax.scan(
partial(self._step, policy=self._epsilon_greedy_action),
state,
step_keys,
)
state, info = self._update(update_key, state)
lox.log(info)
return state.replace(update_step=state.update_step + 1), None
[docs]
def init(self, key: Key) -> R2D2State:
env_key, q_key, torso_key = jax.random.split(key, 3)
env_keys = jax.random.split(env_key, self.cfg.num_envs)
obs, env_state = jax.vmap(self.env.reset, in_axes=(0, None))(
env_keys, self.env_params
)
action_space = self.env.action_space(self.env_params)
action = jnp.zeros(
(self.cfg.num_envs, *action_space.shape), dtype=action_space.dtype
)
reward = jnp.zeros((self.cfg.num_envs,), dtype=jnp.float32)
done = jnp.ones((self.cfg.num_envs,), dtype=jnp.bool_)
carry = self.q_network.initialize_carry((self.cfg.num_envs, None))
timestep = Timestep(
obs=obs, action=action, reward=reward, done=done
).to_sequence()
ts_obs, ts_done, ts_action, ts_reward = timestep
params = self.q_network.init(
{"params": q_key, "torso": torso_key},
observation=ts_obs,
done=ts_done,
action=ts_action,
reward=ts_reward,
initial_carry=carry,
)
target_params = params
optimizer_state = self.optimizer.init(params)
timestep = timestep.from_sequence()
transition = Transition(
first=timestep,
second=timestep,
carry=carry,
)
buffer_state = self.buffer.init(jax.tree.map(lambda x: x[0], transition))
return R2D2State(
step=0,
update_step=0,
timestep=timestep,
carry=carry,
env_state=env_state,
params=params,
target_params=target_params,
optimizer_state=optimizer_state,
buffer_state=buffer_state,
)
[docs]
def warmup(self, key: Key, state: R2D2State, num_steps: int) -> R2D2State:
step_keys = jax.random.split(key, num_steps // self.cfg.num_envs)
state, _ = jax.lax.scan(
partial(self._step, policy=self._random_action),
state,
step_keys,
)
return state
[docs]
def train(
self,
key: Key,
state: R2D2State,
num_steps: int,
) -> R2D2State:
num_outer_steps = num_steps // self.cfg.train_frequency
keys = jax.random.split(key, num_outer_steps)
state, _ = jax.lax.scan(
self._update_step,
state,
keys,
)
return state
[docs]
def evaluate(self, key: Key, state: R2D2State, num_steps: int) -> R2D2State:
reset_key, eval_key = jax.random.split(key)
reset_key = jax.random.split(reset_key, self.cfg.num_envs)
obs, env_state = jax.vmap(self.env.reset, in_axes=(0, None))(
reset_key, self.env_params
)
action_space = self.env.action_space(self.env_params)
action = jnp.zeros(
(self.cfg.num_envs, *action_space.shape), dtype=action_space.dtype
)
reward = jnp.zeros((self.cfg.num_envs,), dtype=jnp.float32)
done = jnp.ones((self.cfg.num_envs,), dtype=jnp.bool_)
timestep = Timestep(obs=obs, action=action, reward=reward, done=done)
carry = self.q_network.initialize_carry((self.cfg.num_envs, None))
state = state.replace(timestep=timestep, carry=carry, env_state=env_state)
step_keys = jax.random.split(eval_key, num_steps)
state, _ = jax.lax.scan(
partial(self._step, policy=self._greedy_action),
state,
step_keys,
)
return state