from typing import Callable
import distrax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.linen.initializers import constant
from memorax.utils.axes import add_feature_axis, remove_feature_axis
from memorax.utils.typing import Array, PyTree
[docs]
class DiscreteQNetwork(nn.Module):
action_dim: int
kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()
[docs]
@nn.compact
def __call__(self, x: Array, **kwargs) -> tuple[Array, dict]:
q_values = nn.Dense(
self.action_dim, kernel_init=self.kernel_init, bias_init=self.bias_init
)(x)
return q_values, {}
[docs]
def loss(
self, output: Array, aux: dict, targets: Array, **kwargs
) -> Array:
return 0.5 * jnp.square(output - targets)
[docs]
class ContinuousQNetwork(nn.Module):
kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()
[docs]
@nn.compact
def __call__(
self, x: Array, *, action: Array, **kwargs
) -> tuple[Array, dict]:
q_values = nn.Dense(1, kernel_init=self.kernel_init, bias_init=self.bias_init)(
jnp.concatenate([x, action], axis=-1)
)
return jnp.squeeze(q_values, -1), {}
[docs]
def loss(
self, output: Array, aux: dict, targets: Array, **kwargs
) -> Array:
return 0.5 * jnp.square(output - targets)
[docs]
class TwinContinuousQNetwork(nn.Module):
kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()
[docs]
@nn.compact
def __call__(
self, x: Array, *, action: Array, **kwargs
) -> tuple[tuple[Array, Array], dict]:
inp = jnp.concatenate([x, action], axis=-1)
q1 = nn.Dense(1, kernel_init=self.kernel_init, bias_init=self.bias_init, name="q1")(inp)
q2 = nn.Dense(1, kernel_init=self.kernel_init, bias_init=self.bias_init, name="q2")(inp)
return (jnp.squeeze(q1, -1), jnp.squeeze(q2, -1)), {}
[docs]
def loss(
self, output: Array, aux: dict, targets: Array, **kwargs
) -> Array:
return 0.5 * jnp.square(output - targets)
[docs]
class VNetwork(nn.Module):
kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()
[docs]
@nn.compact
def __call__(self, x: Array, **kwargs) -> tuple[Array, dict]:
v_value = nn.Dense(1, kernel_init=self.kernel_init, bias_init=self.bias_init)(x)
return v_value, {}
[docs]
def loss(
self, output: Array, aux: dict, targets: Array, **kwargs
) -> Array:
return 0.5 * jnp.square(output - targets)
[docs]
class HLGaussVNetwork(nn.Module):
"""HL-Gauss value head with two-hot cross-entropy loss."""
num_bins: int = 101
v_min: float = -10.0
v_max: float = 10.0
kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()
[docs]
def setup(self):
self.bin_width = (self.v_max - self.v_min) / (self.num_bins - 1)
self.bin_centers = jnp.linspace(self.v_min, self.v_max, self.num_bins)
[docs]
@nn.compact
def __call__(self, x: Array, **kwargs) -> tuple[Array, dict]:
logits = nn.Dense(
self.num_bins, kernel_init=self.kernel_init, bias_init=self.bias_init
)(x)
probs = jax.nn.softmax(logits, axis=-1)
value = jnp.sum(probs * self.bin_centers, axis=-1, keepdims=True)
return value, {"logits": logits}
[docs]
@nn.nowrap
def loss(
self, output: Array, aux: dict, targets: Array, **kwargs
) -> Array:
"""Two-hot cross-entropy loss."""
logits = aux["logits"]
targets = remove_feature_axis(targets)
targets = jnp.clip(targets, self.v_min, self.v_max)
lower_idx = ((targets - self.v_min) / self.bin_width).astype(jnp.int32)
lower_idx = jnp.clip(lower_idx, 0, self.num_bins - 2)
upper_idx = lower_idx + 1
upper_weight = (
targets - (self.v_min + lower_idx * self.bin_width)
) / self.bin_width
lower_weight = 1.0 - upper_weight
log_probs = jax.nn.log_softmax(logits, axis=-1)
lower_log_prob = jnp.take_along_axis(
log_probs, lower_idx[..., None].astype(jnp.int32), axis=-1
).squeeze(-1)
upper_log_prob = jnp.take_along_axis(
log_probs, upper_idx[..., None].astype(jnp.int32), axis=-1
).squeeze(-1)
loss = -(lower_weight * lower_log_prob + upper_weight * upper_log_prob)
return loss
[docs]
class C51QNetwork(nn.Module):
action_dim: int
num_atoms: int = 51
v_min: float = -10.0
v_max: float = 10.0
kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()
[docs]
def setup(self):
self.delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1)
self.atoms = jnp.linspace(self.v_min, self.v_max, self.num_atoms)
[docs]
@nn.compact
def __call__(self, x: Array, **kwargs) -> tuple[Array, dict]:
logits = nn.Dense(
self.action_dim * self.num_atoms,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
)(x)
batch_shape = logits.shape[:-1]
logits = logits.reshape(*batch_shape, self.action_dim, self.num_atoms)
probs = jax.nn.softmax(logits, axis=-1)
q_values = jnp.sum(probs * self.atoms, axis=-1)
return q_values, {"logits": logits, "probs": probs}
[docs]
@nn.nowrap
def loss(
self, output: Array, aux: dict, targets: Array, **kwargs
) -> Array:
logits = aux["logits"]
targets = jnp.clip(targets, self.v_min, self.v_max)
lower_idx = ((targets - self.v_min) / self.delta_z).astype(jnp.int32)
lower_idx = jnp.clip(lower_idx, 0, self.num_atoms - 2)
upper_idx = lower_idx + 1
upper_weight = (
targets - (self.v_min + lower_idx * self.delta_z)
) / self.delta_z
lower_weight = 1.0 - upper_weight
log_probs = jax.nn.log_softmax(logits, axis=-1)
lower_log_prob = jnp.take_along_axis(
log_probs, lower_idx[..., None].astype(jnp.int32), axis=-1
).squeeze(-1)
upper_log_prob = jnp.take_along_axis(
log_probs, upper_idx[..., None].astype(jnp.int32), axis=-1
).squeeze(-1)
loss = -(lower_weight * lower_log_prob + upper_weight * upper_log_prob)
return loss
[docs]
class Categorical(nn.Module):
action_dim: int
kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()
[docs]
@nn.compact
def __call__(self, x: Array, **kwargs) -> tuple[distrax.Categorical, dict]:
logits = nn.Dense(
self.action_dim, kernel_init=self.kernel_init, bias_init=self.bias_init
)(x)
return distrax.Categorical(logits=logits), {}
[docs]
class Gaussian(nn.Module):
action_dim: int
kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()
[docs]
@nn.compact
def __call__(
self, x: Array, **kwargs
) -> tuple[distrax.MultivariateNormalDiag, dict]:
mean = nn.Dense(
self.action_dim, kernel_init=self.kernel_init, bias_init=self.bias_init
)(x)
log_std = self.param("log_std", nn.initializers.zeros, self.action_dim)
std = jnp.exp(log_std)
return distrax.MultivariateNormalDiag(loc=mean, scale_diag=std), {}
[docs]
class SquashedGaussian(nn.Module):
action_dim: int
kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()
LOG_STD_MIN = -10
LOG_STD_MAX = 2
[docs]
@nn.compact
def __call__(self, x: Array, **kwargs) -> tuple[distrax.Transformed, dict]:
temperature = kwargs.get("temperature", 1.0)
mean = nn.Dense(
self.action_dim, kernel_init=self.kernel_init, bias_init=self.bias_init
)(x)
log_std = nn.Dense(
self.action_dim, kernel_init=self.kernel_init, bias_init=self.bias_init
)(x)
log_std = jnp.clip(log_std, self.LOG_STD_MIN, self.LOG_STD_MAX)
std = jnp.exp(log_std)
dist = distrax.MultivariateNormalDiag(loc=mean, scale_diag=std * temperature)
return distrax.Transformed(dist, distrax.Block(distrax.Tanh(), ndims=1)), {}
[docs]
class Alpha(nn.Module):
initial_alpha: float
[docs]
@nn.compact
def __call__(self) -> Array:
log_alpha = self.param(
"log_temp",
constant(jnp.log(self.initial_alpha)),
(),
)
return log_alpha
[docs]
class Beta(nn.Module):
initial_beta: float
[docs]
@nn.compact
def __call__(self) -> Array:
log_beta = self.param(
"log_temp",
constant(jnp.log(self.initial_beta)),
(),
)
return log_beta
[docs]
class GVF(nn.Module):
head: nn.Module
gamma: float
cumulant: Callable
[docs]
def __call__(self, x: Array, **kwargs) -> tuple[Array, dict]:
return self.head(x, **kwargs)
[docs]
@nn.nowrap
def get_target(self, transition: PyTree, next_value: Array) -> Array:
next_value = jax.lax.stop_gradient(next_value)
return (
self.cumulant(transition)
+ self.gamma * (1 - transition.second.done) * next_value
)
[docs]
@nn.nowrap
def loss(self, output: Array, aux: dict, targets: Array, **kwargs) -> Array:
return self.head.loss(output, aux, targets, **kwargs)
[docs]
class Horde(nn.Module):
head: nn.Module
demons: dict[str, nn.Module]
[docs]
def __call__(self, x: Array, **kwargs) -> tuple[Array, dict]:
output, aux = self.head(x, **kwargs)
demons = {}
for name, demon in self.demons.items():
demons[name] = demon(x, **kwargs)
return output, {**aux, "demons": demons}
[docs]
@nn.nowrap
def loss(self, output: Array, aux: dict, targets: Array, **kwargs) -> Array:
loss = self.head.loss(output, aux, targets, **kwargs)
transitions = kwargs.get("transitions")
for name, demon in self.demons.items():
values, _ = aux["demons"][name]
padding = ((0, 0, 0),) + ((-1, 1, 0),) + ((0, 0, 0),) * (values.ndim - 2)
next_values = jax.lax.pad(values, 0.0, padding)
demon_targets = demon.get_target(
transitions, remove_feature_axis(next_values)
)
loss = loss + demon.loss(
*aux["demons"][name],
add_feature_axis(demon_targets),
transitions=transitions,
)
return loss
[docs]
class PredecessorHead(nn.Module):
features: int
kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()
[docs]
@nn.compact
def __call__(self, x: Array, **kwargs) -> tuple[tuple[Array, Array], dict]:
phi = nn.Dense(
self.features, kernel_init=self.kernel_init, bias_init=self.bias_init
)(x)
psi_back = nn.Dense(
self.features, kernel_init=self.kernel_init, bias_init=self.bias_init
)(x)
return (phi, psi_back), {}