from dataclasses import dataclass
from functools import partial
from typing import Any, Callable
import flax.linen as nn
import jax
import jax.numpy as jnp
import lox
from flax import core, struct
from memorax.utils.axes import (
add_time_axis,
remove_feature_axis,
remove_time_axis,
)
from memorax.utils import Timestep, Transition
from memorax.utils.typing import Array, Discrete, Environment, EnvParams, EnvState, Key, Carry, PyTree
[docs]
@struct.dataclass(frozen=True)
class StreamACConfig:
num_envs: int
gamma: float
trace_lambda: float
actor_lr: float
critic_lr: float
actor_kappa: float = 3.0
critic_kappa: float = 2.0
entropy_coefficient: float = 0.01
adaptive: bool = False
beta2: float = 0.999
eps: float = 1e-8
[docs]
@struct.dataclass(frozen=True)
class StreamACState:
step: int
update_step: int
timestep: Timestep
env_state: EnvState
actor_params: core.FrozenDict[str, Any]
actor_traces: core.FrozenDict[str, Any]
actor_v: core.FrozenDict[str, Any]
actor_carry: Array
critic_params: core.FrozenDict[str, Any]
critic_traces: core.FrozenDict[str, Any]
critic_v: core.FrozenDict[str, Any]
critic_carry: Array
[docs]
@dataclass
class StreamAC:
cfg: StreamACConfig
env: Environment
env_params: EnvParams
actor_network: nn.Module
critic_network: nn.Module
def _deterministic_action(
self, key: Key, state: StreamACState
) -> tuple[StreamACState, Array, Array, None, dict]:
obs, done, action, reward = state.timestep.to_sequence()
(actor_carry, (probs, _)), intermediates = self.actor_network.apply(
state.actor_params,
observation=obs,
action=action,
reward=reward,
done=done,
initial_carry=state.actor_carry,
mutable=["intermediates"],
)
action = (
jnp.argmax(probs.logits, axis=-1)
if isinstance(self.env.action_space(self.env_params), Discrete)
else probs.mode()
)
log_prob = probs.log_prob(action)
action = remove_time_axis(action)
log_prob = remove_time_axis(log_prob)
state = state.replace(actor_carry=actor_carry)
return state, action, log_prob, None, intermediates
def _stochastic_action(
self, key: Key, state: StreamACState
) -> tuple[StreamACState, Array, Array, Array, dict]:
action_key, actor_torso_key, critic_torso_key = jax.random.split(key, 3)
obs, done, ts_action, reward = state.timestep.to_sequence()
(actor_carry, (probs, _)), intermediates = self.actor_network.apply(
state.actor_params,
observation=obs,
action=ts_action,
reward=reward,
done=done,
initial_carry=state.actor_carry,
rngs={"torso": actor_torso_key},
mutable=["intermediates"],
)
action, log_prob = probs.sample_and_log_prob(seed=action_key)
critic_carry, (value, _) = self.critic_network.apply(
state.critic_params,
observation=obs,
action=ts_action,
reward=reward,
done=done,
initial_carry=state.critic_carry,
rngs={"torso": critic_torso_key},
)
action = remove_time_axis(action)
log_prob = remove_time_axis(log_prob)
value = remove_time_axis(value)
value = remove_feature_axis(value)
state = state.replace(actor_carry=actor_carry, critic_carry=critic_carry)
return state, action, log_prob, value, intermediates
def _step(
self, state: StreamACState, key: Key, *, policy: Callable
) -> tuple[StreamACState, Transition]:
action_key, step_key = jax.random.split(key)
state, action, log_prob, value, intermediates = policy(action_key, state)
num_envs, *_ = state.timestep.obs.shape
step_keys = jax.random.split(step_key, num_envs)
next_obs, env_state, reward, done, info = jax.vmap(
self.env.step, in_axes=(0, 0, 0, None)
)(step_keys, state.env_state, action, self.env_params)
intermediates = jax.tree.map(
lambda x: jnp.mean(jnp.stack(x)),
intermediates.get("intermediates", {}),
is_leaf=lambda x: isinstance(x, tuple),
)
broadcast_dims = tuple(range(state.timestep.done.ndim, state.timestep.action.ndim))
first = Timestep(
obs=state.timestep.obs,
action=state.timestep.action,
reward=state.timestep.reward,
done=state.timestep.done,
)
second = Timestep(obs=None, action=action, reward=reward, done=done)
lox.log({"info": info, "intermediates": intermediates})
transition = Transition(
first=first,
second=second,
aux={"log_prob": log_prob, "value": value},
)
next_reward = jnp.asarray(reward, dtype=jnp.float32)
state = state.replace(
step=state.step + self.cfg.num_envs,
timestep=Timestep(
obs=next_obs,
action=jnp.where(
jnp.expand_dims(done, axis=broadcast_dims),
jnp.zeros_like(action),
action,
),
reward=jnp.where(done, jnp.zeros_like(next_reward), next_reward),
done=done,
),
env_state=env_state,
)
return state, transition
def _obgd_update(self, traces: PyTree, v: PyTree, td_error: Array, lr: float, kappa: float, step: int):
beta2 = self.cfg.beta2
eps = self.cfg.eps
def _broadcast_delta(td_error, z):
n_trailing = z.ndim - 1
return td_error[(slice(None),) + (None,) * n_trailing]
# Update second moment: v <- beta2*v + (1-beta2)*(delta*z)^2
new_v = jax.tree.map(
lambda vi, z: beta2 * vi + (1 - beta2) * jnp.square(_broadcast_delta(td_error, z) * z),
v, traces,
)
if self.cfg.adaptive:
# Bias-corrected v_hat = v / (1 - beta2^t)
v_hat = jax.tree.map(lambda vi: vi / (1.0 - beta2 ** step), new_v)
# z_sum over normalised traces: sum|z / sqrt(v_hat + eps)|
norm_leaves = jax.tree.leaves(jax.tree.map(
lambda z, vh: jnp.abs(z) / (jnp.sqrt(vh) + eps), traces, v_hat,
))
z_sum = sum(
jnp.sum(z, axis=tuple(range(1, z.ndim))) for z in norm_leaves
)
else:
v_hat = None
z_leaves = jax.tree.leaves(traces)
z_sum = sum(
jnp.sum(jnp.abs(z), axis=tuple(range(1, z.ndim))) for z in z_leaves
)
delta_bar = jnp.maximum(jnp.abs(td_error), 1.0)
step_size = lr / jnp.maximum(1.0, delta_bar * z_sum * lr * kappa)
if self.cfg.adaptive:
def compute_update(z: Array, vh: Array):
n_trailing = z.ndim - 1
ss = step_size[(slice(None),) + (None,) * n_trailing]
delta = td_error[(slice(None),) + (None,) * n_trailing]
return (ss * delta * z / (jnp.sqrt(vh) + eps)).mean(axis=0)
updates = jax.tree.map(compute_update, traces, v_hat)
else:
def compute_update(z: Array):
n_trailing = z.ndim - 1
ss = step_size[(slice(None),) + (None,) * n_trailing]
delta = td_error[(slice(None),) + (None,) * n_trailing]
return (ss * delta * z).mean(axis=0)
updates = jax.tree.map(compute_update, traces)
return updates, new_v
def _update_step(self, state: StreamACState, key: Key) -> tuple[StreamACState, None]:
action_key, step_key, actor_torso_key, critic_torso_key = jax.random.split(key, 4)
obs, done, ts_action, reward = state.timestep.to_sequence()
(actor_carry, (probs, _)), intermediates = self.actor_network.apply(
state.actor_params,
observation=obs,
action=ts_action,
reward=reward,
done=done,
initial_carry=state.actor_carry,
rngs={"torso": actor_torso_key},
mutable=["intermediates"],
)
action, log_prob = probs.sample_and_log_prob(seed=action_key)
entropy = remove_time_axis(probs.entropy()).mean()
action = remove_time_axis(action)
log_prob = remove_time_axis(log_prob)
critic_carry, (value, _) = self.critic_network.apply(
state.critic_params,
observation=obs,
action=ts_action,
reward=reward,
done=done,
initial_carry=state.critic_carry,
rngs={"torso": critic_torso_key},
)
value = remove_time_axis(value)
value = remove_feature_axis(value)
num_envs, *_ = state.timestep.obs.shape
step_key = jax.random.split(step_key, num_envs)
next_obs, env_state, next_reward, next_done, info = jax.vmap(
self.env.step, in_axes=(0, 0, 0, None)
)(step_key, state.env_state, action, self.env_params)
next_obs_s, next_done_s, next_action_s, next_reward_s = Timestep(
obs=next_obs, action=action, reward=next_reward, done=next_done
).to_sequence()
_, (next_value, _) = self.critic_network.apply(
jax.lax.stop_gradient(state.critic_params),
observation=next_obs_s,
action=next_action_s,
reward=next_reward_s,
done=next_done_s,
initial_carry=jax.lax.stop_gradient(critic_carry),
)
next_value = remove_time_axis(next_value)
next_value = remove_feature_axis(next_value)
gamma = self.cfg.gamma
td_error = next_reward + gamma * (1 - next_done) * next_value - value
initial_actor_carry = jax.lax.stop_gradient(state.actor_carry)
initial_critic_carry = jax.lax.stop_gradient(state.critic_carry)
def critic_loss_fn(params: PyTree):
_, (v, _) = self.critic_network.apply(
params,
observation=obs,
action=ts_action,
reward=reward,
done=done,
initial_carry=initial_critic_carry,
)
return remove_feature_axis(remove_time_axis(v))
def actor_loss_fn(params: PyTree):
_, (dist, _) = self.actor_network.apply(
params,
observation=obs,
action=ts_action,
reward=reward,
done=done,
initial_carry=initial_actor_carry,
)
log_p = remove_time_axis(dist.log_prob(add_time_axis(action)))
entropy = remove_time_axis(dist.entropy())
return log_p + self.cfg.entropy_coefficient * jnp.sign(td_error) * entropy
critic_grads = jax.jacobian(critic_loss_fn)(state.critic_params)
actor_grads = jax.jacobian(actor_loss_fn)(state.actor_params)
trace_decay = gamma * self.cfg.trace_lambda
def update_trace(z: Array, g: Array):
n_trailing = z.ndim - 1
not_done = (1 - state.timestep.done)[(slice(None),) + (None,) * n_trailing]
return trace_decay * not_done * z + g
critic_traces = jax.tree.map(update_trace, state.critic_traces, critic_grads)
actor_traces = jax.tree.map(update_trace, state.actor_traces, actor_grads)
current_step = state.update_step + 1
critic_updates, critic_v = self._obgd_update(critic_traces, state.critic_v, td_error, self.cfg.critic_lr, self.cfg.critic_kappa, current_step)
actor_updates, actor_v = self._obgd_update(actor_traces, state.actor_v, td_error, self.cfg.actor_lr, self.cfg.actor_kappa, current_step)
critic_params = jax.tree.map(lambda p, u: p + u, state.critic_params, critic_updates)
actor_params = jax.tree.map(lambda p, u: p + u, state.actor_params, actor_updates)
intermediates = jax.tree.map(
lambda x: jnp.mean(jnp.stack(x)),
intermediates.get("intermediates", {}),
is_leaf=lambda x: isinstance(x, tuple),
)
broadcast_dims = tuple(range(state.timestep.done.ndim, state.timestep.action.ndim))
first = Timestep(
obs=state.timestep.obs,
action=state.timestep.action,
reward=state.timestep.reward,
done=state.timestep.done,
)
second = Timestep(obs=None, action=action, reward=next_reward, done=next_done)
lox.log({
"info": info,
"intermediates": intermediates,
"critic/td_error": td_error.mean(),
"actor/entropy": entropy,
"critic/value": value.mean(),
})
next_reward_f = jnp.asarray(next_reward, dtype=jnp.float32)
state = state.replace(
step=state.step + self.cfg.num_envs,
update_step=current_step,
timestep=Timestep(
obs=next_obs,
action=jnp.where(
jnp.expand_dims(next_done, axis=broadcast_dims),
jnp.zeros_like(action),
action,
),
reward=jnp.where(
next_done, jnp.zeros_like(next_reward_f), next_reward_f
),
done=next_done,
),
env_state=env_state,
actor_params=actor_params,
actor_traces=actor_traces,
actor_v=actor_v,
actor_carry=actor_carry,
critic_params=critic_params,
critic_traces=critic_traces,
critic_v=critic_v,
critic_carry=critic_carry,
)
return state, None
[docs]
def init(self, key: Key) -> StreamACState:
(
env_key,
actor_key,
actor_torso_key,
actor_dropout_key,
critic_key,
critic_torso_key,
critic_dropout_key,
) = jax.random.split(key, 7)
env_keys = jax.random.split(env_key, self.cfg.num_envs)
obs, env_state = jax.vmap(self.env.reset, in_axes=(0, None))(
env_keys, self.env_params
)
action = jnp.zeros(
(self.cfg.num_envs, *self.env.action_space(self.env_params).shape),
dtype=self.env.action_space(self.env_params).dtype,
)
reward = jnp.zeros((self.cfg.num_envs,), dtype=jnp.float32)
done = jnp.ones((self.cfg.num_envs,), dtype=jnp.bool_)
timestep = Timestep(obs=obs, action=action, reward=reward, done=done).to_sequence()
actor_carry = self.actor_network.initialize_carry((self.cfg.num_envs, None))
critic_carry = self.critic_network.initialize_carry((self.cfg.num_envs, None))
ts_obs, ts_done, ts_action, ts_reward = timestep
actor_params = self.actor_network.init(
{
"params": actor_key,
"torso": actor_torso_key,
"dropout": actor_dropout_key,
},
observation=ts_obs,
action=ts_action,
reward=ts_reward,
done=ts_done,
initial_carry=actor_carry,
)
critic_params = self.critic_network.init(
{
"params": critic_key,
"torso": critic_torso_key,
"dropout": critic_dropout_key,
},
observation=ts_obs,
action=ts_action,
reward=ts_reward,
done=ts_done,
initial_carry=critic_carry,
)
actor_traces = jax.tree.map(
lambda p: jnp.zeros((self.cfg.num_envs, *p.shape)), actor_params
)
critic_traces = jax.tree.map(
lambda p: jnp.zeros((self.cfg.num_envs, *p.shape)), critic_params
)
actor_v = jax.tree.map(jnp.zeros_like, actor_traces)
critic_v = jax.tree.map(jnp.zeros_like, critic_traces)
return StreamACState(
step=0,
update_step=0,
timestep=timestep.from_sequence(),
env_state=env_state,
actor_params=actor_params,
actor_traces=actor_traces,
actor_v=actor_v,
actor_carry=actor_carry,
critic_params=critic_params,
critic_traces=critic_traces,
critic_v=critic_v,
critic_carry=critic_carry,
)
[docs]
def warmup(self, key: Key, state: StreamACState, num_steps: int) -> StreamACState:
return state
[docs]
def train(self, key: Key, state: StreamACState, num_steps: int) -> StreamACState:
keys = jax.random.split(key, num_steps // self.cfg.num_envs)
state, _ = jax.lax.scan(
self._update_step,
state,
keys,
)
return state
[docs]
def evaluate(self, key: Key, state: StreamACState, num_steps: int) -> StreamACState:
reset_key, eval_key = jax.random.split(key)
reset_key = jax.random.split(reset_key, self.cfg.num_envs)
obs, env_state = jax.vmap(self.env.reset, in_axes=(0, None))(
reset_key, self.env_params
)
action = jnp.zeros(
(self.cfg.num_envs, *self.env.action_space(self.env_params).shape),
dtype=self.env.action_space(self.env_params).dtype,
)
reward = jnp.zeros((self.cfg.num_envs,), dtype=jnp.float32)
done = jnp.ones((self.cfg.num_envs,), dtype=jnp.bool_)
timestep = Timestep(obs=obs, action=action, reward=reward, done=done)
initial_actor_carry = self.actor_network.initialize_carry((self.cfg.num_envs, None))
initial_critic_carry = self.critic_network.initialize_carry((self.cfg.num_envs, None))
state = state.replace(
timestep=timestep,
actor_carry=initial_actor_carry,
critic_carry=initial_critic_carry,
env_state=env_state,
)
step_keys = jax.random.split(eval_key, num_steps)
state, _ = jax.lax.scan(
partial(self._step, policy=self._deterministic_action),
state,
step_keys,
)
return state