Source code for memorax.networks.heads

from typing import Callable

import distrax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.linen.initializers import constant

from memorax.utils.axes import add_feature_axis, remove_feature_axis
from memorax.utils.typing import Array, PyTree


class DiscreteQNetwork(nn.Module):
    """Linear head that emits one Q-value per discrete action."""

    # Width of the output layer (one unit per action).
    action_dim: int
    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()

    @nn.compact
    def __call__(self, x: Array, **kwargs) -> tuple[Array, dict]:
        """Project ``x`` onto ``action_dim`` outputs.

        Extra keyword arguments are accepted and ignored. Returns the
        Q-values together with an empty auxiliary dict.
        """
        projection = nn.Dense(
            features=self.action_dim,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )
        return projection(x), {}

    def loss(
        self, output: Array, aux: dict, targets: Array, **kwargs
    ) -> Array:
        """Return the element-wise half squared error ``0.5 * (output - targets)**2``.

        ``aux`` and any keyword arguments are unused. No reduction is applied.
        """
        residual = output - targets
        return 0.5 * jnp.square(residual)
class ContinuousQNetwork(nn.Module):
    """Linear head that scores a (features, action) pair with a single Q-value."""

    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()

    @nn.compact
    def __call__(
        self, x: Array, *, action: Array, **kwargs
    ) -> tuple[Array, dict]:
        """Concatenate ``x`` and ``action`` on the last axis and map to a scalar.

        The trailing size-1 axis produced by the dense layer is squeezed
        away. Extra keyword arguments are ignored; the auxiliary dict is empty.
        """
        critic = nn.Dense(
            features=1, kernel_init=self.kernel_init, bias_init=self.bias_init
        )
        state_action = jnp.concatenate([x, action], axis=-1)
        q_value = critic(state_action)
        return jnp.squeeze(q_value, -1), {}

    def loss(
        self, output: Array, aux: dict, targets: Array, **kwargs
    ) -> Array:
        """Return the element-wise half squared error ``0.5 * (output - targets)**2``.

        ``aux`` and any keyword arguments are unused. No reduction is applied.
        """
        residual = output - targets
        return 0.5 * jnp.square(residual)
class TwinContinuousQNetwork(nn.Module):
    """Two independent linear Q-heads (``q1`` and ``q2``) over a (features, action) pair."""

    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()

    @nn.compact
    def __call__(
        self, x: Array, *, action: Array, **kwargs
    ) -> tuple[tuple[Array, Array], dict]:
        """Return ``((q1, q2), {})`` for the concatenated ``x`` and ``action``.

        Both heads read the same input; each output has its trailing size-1
        axis squeezed away. Extra keyword arguments are ignored.
        """
        inp = jnp.concatenate([x, action], axis=-1)
        q1 = nn.Dense(
            1, kernel_init=self.kernel_init, bias_init=self.bias_init, name="q1"
        )(inp)
        q2 = nn.Dense(
            1, kernel_init=self.kernel_init, bias_init=self.bias_init, name="q2"
        )(inp)
        return (jnp.squeeze(q1, -1), jnp.squeeze(q2, -1)), {}

    def loss(
        self, output: Array, aux: dict, targets: Array, **kwargs
    ) -> Array:
        """Return the element-wise half squared error against ``targets``.

        ``output`` may be a single array, or the ``(q1, q2)`` tuple that
        ``__call__`` returns. For a tuple/list the per-critic losses are
        added together, since subtracting an array from a Python tuple is a
        ``TypeError``.
        NOTE(review): summing (rather than averaging) the two critic losses
        is an assumption; confirm against the training code.

        ``aux`` and any keyword arguments are unused. No reduction is applied.
        """
        if isinstance(output, (tuple, list)):
            return sum(0.5 * jnp.square(q - targets) for q in output)
        return 0.5 * jnp.square(output - targets)
class VNetwork(nn.Module):
    """Linear state-value head with a single output unit."""

    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()

    @nn.compact
    def __call__(self, x: Array, **kwargs) -> tuple[Array, dict]:
        """Map ``x`` to a value estimate, keeping the trailing size-1 axis.

        Extra keyword arguments are ignored; the auxiliary dict is empty.
        """
        value_layer = nn.Dense(
            features=1, kernel_init=self.kernel_init, bias_init=self.bias_init
        )
        return value_layer(x), {}

    def loss(
        self, output: Array, aux: dict, targets: Array, **kwargs
    ) -> Array:
        """Return the element-wise half squared error ``0.5 * (output - targets)**2``.

        ``aux`` and any keyword arguments are unused. No reduction is applied.
        """
        residual = output - targets
        return 0.5 * jnp.square(residual)
class HLGaussVNetwork(nn.Module):
    """HL-Gauss value head with two-hot cross-entropy loss."""

    # Number of evenly spaced bins spanning [v_min, v_max].
    num_bins: int = 101
    v_min: float = -10.0
    v_max: float = 10.0
    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()

    def setup(self):
        """Precompute the bin spacing and the bin centre locations."""
        self.bin_width = (self.v_max - self.v_min) / (self.num_bins - 1)
        self.bin_centers = jnp.linspace(self.v_min, self.v_max, self.num_bins)

    @nn.compact
    def __call__(self, x: Array, **kwargs) -> tuple[Array, dict]:
        """Return the expected value under the predicted bin distribution.

        The value keeps a trailing size-1 axis; the raw ``logits`` are
        returned in the auxiliary dict for use by ``loss``. Extra keyword
        arguments are ignored.
        """
        logits = nn.Dense(
            self.num_bins, kernel_init=self.kernel_init, bias_init=self.bias_init
        )(x)
        probs = jax.nn.softmax(logits, axis=-1)
        value = jnp.sum(probs * self.bin_centers, axis=-1, keepdims=True)
        return value, {"logits": logits}

    @nn.nowrap
    def loss(
        self, output: Array, aux: dict, targets: Array, **kwargs
    ) -> Array:
        """Two-hot cross-entropy loss.

        Each (clipped) target is split between its two neighbouring bins with
        weights proportional to proximity, and the cross-entropy against
        ``aux["logits"]`` is returned per element (no reduction). ``output``
        and keyword arguments are unused.
        """
        # Derive the spacing from the dataclass fields rather than reading
        # ``self.bin_width``: attributes assigned in ``setup`` only exist on
        # a bound module, and this un-wrapped method may be called outside
        # ``init``/``apply``.
        bin_width = (self.v_max - self.v_min) / (self.num_bins - 1)

        logits = aux["logits"]
        # presumably strips the trailing feature axis so targets line up with
        # the batch dims of ``logits`` — verify against memorax.utils.axes.
        targets = remove_feature_axis(targets)
        targets = jnp.clip(targets, self.v_min, self.v_max)

        # Index of the bin at or below the target; capped at num_bins - 2 so
        # the upper neighbour is always a valid bin.
        lower_idx = ((targets - self.v_min) / bin_width).astype(jnp.int32)
        lower_idx = jnp.clip(lower_idx, 0, self.num_bins - 2)
        upper_idx = lower_idx + 1

        # Fractional position of the target between the two bins.
        upper_weight = (targets - (self.v_min + lower_idx * bin_width)) / bin_width
        lower_weight = 1.0 - upper_weight

        log_probs = jax.nn.log_softmax(logits, axis=-1)
        lower_log_prob = jnp.take_along_axis(
            log_probs, lower_idx[..., None], axis=-1
        ).squeeze(-1)
        upper_log_prob = jnp.take_along_axis(
            log_probs, upper_idx[..., None], axis=-1
        ).squeeze(-1)

        return -(lower_weight * lower_log_prob + upper_weight * upper_log_prob)
class C51QNetwork(nn.Module):
    """Categorical Q-head: a distribution over ``num_atoms`` fixed atoms per action."""

    action_dim: int
    # Number of evenly spaced atoms spanning [v_min, v_max].
    num_atoms: int = 51
    v_min: float = -10.0
    v_max: float = 10.0
    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()

    def setup(self):
        """Precompute the atom spacing and the atom locations."""
        self.delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1)
        self.atoms = jnp.linspace(self.v_min, self.v_max, self.num_atoms)

    @nn.compact
    def __call__(self, x: Array, **kwargs) -> tuple[Array, dict]:
        """Return expected Q-values of shape ``(*batch, action_dim)``.

        The auxiliary dict carries the per-action ``logits`` and ``probs``,
        both shaped ``(*batch, action_dim, num_atoms)``. Extra keyword
        arguments are ignored.
        """
        logits = nn.Dense(
            self.action_dim * self.num_atoms,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )(x)
        batch_shape = logits.shape[:-1]
        logits = logits.reshape(*batch_shape, self.action_dim, self.num_atoms)
        probs = jax.nn.softmax(logits, axis=-1)
        q_values = jnp.sum(probs * self.atoms, axis=-1)
        return q_values, {"logits": logits, "probs": probs}

    @nn.nowrap
    def loss(
        self, output: Array, aux: dict, targets: Array, **kwargs
    ) -> Array:
        """Two-hot cross-entropy between scalar targets and ``aux["logits"]``.

        Each (clipped) target is split between its two neighbouring atoms
        with weights proportional to proximity. The result is per element
        (no reduction). ``output`` and keyword arguments are unused.

        NOTE(review): the indices derived from ``targets`` are gathered along
        the atom axis of ``logits``, so ``targets`` must be broadcast-
        compatible with ``logits`` minus that last axis — confirm against
        callers.
        """
        # Derive the spacing from the dataclass fields rather than reading
        # ``self.delta_z``: attributes assigned in ``setup`` only exist on a
        # bound module, and this un-wrapped method may be called outside
        # ``init``/``apply``.
        delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1)

        logits = aux["logits"]
        targets = jnp.clip(targets, self.v_min, self.v_max)

        # Index of the atom at or below the target; capped at num_atoms - 2
        # so the upper neighbour is always a valid atom.
        lower_idx = ((targets - self.v_min) / delta_z).astype(jnp.int32)
        lower_idx = jnp.clip(lower_idx, 0, self.num_atoms - 2)
        upper_idx = lower_idx + 1

        # Fractional position of the target between the two atoms.
        upper_weight = (targets - (self.v_min + lower_idx * delta_z)) / delta_z
        lower_weight = 1.0 - upper_weight

        log_probs = jax.nn.log_softmax(logits, axis=-1)
        lower_log_prob = jnp.take_along_axis(
            log_probs, lower_idx[..., None], axis=-1
        ).squeeze(-1)
        upper_log_prob = jnp.take_along_axis(
            log_probs, upper_idx[..., None], axis=-1
        ).squeeze(-1)

        return -(lower_weight * lower_log_prob + upper_weight * upper_log_prob)
class Categorical(nn.Module):
    """Policy head that parameterises a categorical distribution over actions."""

    # Number of logits produced (one per action).
    action_dim: int
    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()

    @nn.compact
    def __call__(self, x: Array, **kwargs) -> tuple[distrax.Categorical, dict]:
        """Return ``distrax.Categorical`` built from a dense projection of ``x``.

        Extra keyword arguments are ignored; the auxiliary dict is empty.
        """
        logit_layer = nn.Dense(
            features=self.action_dim,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )
        policy = distrax.Categorical(logits=logit_layer(x))
        return policy, {}
class Gaussian(nn.Module):
    """Diagonal Gaussian policy head with a state-independent log-std parameter."""

    action_dim: int
    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()

    @nn.compact
    def __call__(
        self, x: Array, **kwargs
    ) -> tuple[distrax.MultivariateNormalDiag, dict]:
        """Return a diagonal normal whose mean is a dense projection of ``x``.

        The scale is ``exp(log_std)`` where ``log_std`` is a learned parameter
        initialised to zeros and not a function of ``x``. Extra keyword
        arguments are ignored; the auxiliary dict is empty.
        """
        mean_layer = nn.Dense(
            features=self.action_dim,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )
        loc = mean_layer(x)
        log_std = self.param("log_std", nn.initializers.zeros, self.action_dim)
        scale = jnp.exp(log_std)
        policy = distrax.MultivariateNormalDiag(loc=loc, scale_diag=scale)
        return policy, {}
class SquashedGaussian(nn.Module):
    """Tanh-squashed diagonal Gaussian policy head with state-dependent std."""

    action_dim: int
    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()

    # Clipping bounds for the predicted log-std. Deliberately unannotated so
    # they stay plain class attributes rather than dataclass fields.
    LOG_STD_MIN = -10
    LOG_STD_MAX = 2

    @nn.compact
    def __call__(self, x: Array, **kwargs) -> tuple[distrax.Transformed, dict]:
        """Return a tanh-transformed diagonal normal conditioned on ``x``.

        An optional ``temperature`` keyword (default ``1.0``) multiplies the
        standard deviation. Other keyword arguments are ignored; the
        auxiliary dict is empty.
        """
        temperature = kwargs.get("temperature", 1.0)

        # Layer creation order is kept (mean first, then log-std) so the
        # auto-generated parameter names are unchanged.
        mean_layer = nn.Dense(
            features=self.action_dim,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )
        loc = mean_layer(x)
        log_std_layer = nn.Dense(
            features=self.action_dim,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )
        raw_log_std = log_std_layer(x)

        bounded_log_std = jnp.clip(raw_log_std, self.LOG_STD_MIN, self.LOG_STD_MAX)
        scale = jnp.exp(bounded_log_std)

        base = distrax.MultivariateNormalDiag(loc=loc, scale_diag=scale * temperature)
        squash = distrax.Block(distrax.Tanh(), ndims=1)
        return distrax.Transformed(base, squash), {}
class Alpha(nn.Module):
    """Holds a learnable scalar stored in log space, initialised from ``initial_alpha``."""

    initial_alpha: float

    @nn.compact
    def __call__(self) -> Array:
        """Return the scalar parameter, i.e. the *log* of alpha (not alpha itself)."""
        init_fn = constant(jnp.log(self.initial_alpha))
        return self.param("log_temp", init_fn, ())
class Beta(nn.Module):
    """Holds a learnable scalar stored in log space, initialised from ``initial_beta``."""

    initial_beta: float

    @nn.compact
    def __call__(self) -> Array:
        """Return the scalar parameter, i.e. the *log* of beta (not beta itself)."""
        init_fn = constant(jnp.log(self.initial_beta))
        return self.param("log_temp", init_fn, ())
class GVF(nn.Module):
    """Wraps a value head with its own discount and cumulant to form a TD target."""

    # Head that produces the prediction and defines the loss.
    head: nn.Module
    # Discount applied to the bootstrapped next value.
    gamma: float
    # Callable mapping a transition to the signal being accumulated.
    cumulant: Callable

    def __call__(self, x: Array, **kwargs) -> tuple[Array, dict]:
        """Delegate to the wrapped head and return its ``(output, aux)`` pair."""
        prediction = self.head(x, **kwargs)
        return prediction

    @nn.nowrap
    def get_target(self, transition: PyTree, next_value: Array) -> Array:
        """Return ``cumulant + gamma * (1 - done) * next_value``.

        ``next_value`` is wrapped in ``stop_gradient`` so no gradient flows
        through the bootstrap term; ``done`` is read from
        ``transition.second.done``.
        """
        signal = self.cumulant(transition)
        bootstrap = jax.lax.stop_gradient(next_value)
        continuation = self.gamma * (1 - transition.second.done)
        return signal + continuation * bootstrap

    @nn.nowrap
    def loss(self, output: Array, aux: dict, targets: Array, **kwargs) -> Array:
        """Delegate to the wrapped head's ``loss`` with all arguments unchanged."""
        head_loss = self.head.loss(output, aux, targets, **kwargs)
        return head_loss
class Horde(nn.Module):
    """Main head plus a named collection of auxiliary "demon" heads on the same input."""

    head: nn.Module
    # Auxiliary predictors; each must expose ``__call__``, ``get_target`` and ``loss``.
    demons: dict[str, nn.Module]

    def __call__(self, x: Array, **kwargs) -> tuple[Array, dict]:
        """Run the main head and every demon on ``x``.

        Returns the main head's output; its aux dict is extended with a
        ``"demons"`` entry mapping each demon name to that demon's
        ``(output, aux)`` result.
        """
        output, aux = self.head(x, **kwargs)
        demon_results = {
            name: demon(x, **kwargs) for name, demon in self.demons.items()
        }
        return output, {**aux, "demons": demon_results}

    @nn.nowrap
    def loss(self, output: Array, aux: dict, targets: Array, **kwargs) -> Array:
        """Return the main head's loss plus each demon's bootstrapped loss.

        Demon targets are built from the ``transitions`` keyword argument and
        the demon's own predictions shifted by one step along axis 1.
        """
        total = self.head.loss(output, aux, targets, **kwargs)
        transitions = kwargs.get("transitions")

        no_pad = (0, 0, 0)
        # (low=-1, high=1) on axis 1 drops the first entry and appends a
        # zero, i.e. position t receives the value from position t + 1.
        # presumably axis 1 is the time axis — confirm against callers.
        shift_one = (-1, 1, 0)

        for name, demon in self.demons.items():
            demon_output, demon_aux = aux["demons"][name]
            pad_config = (no_pad, shift_one) + (no_pad,) * (demon_output.ndim - 2)
            next_values = jax.lax.pad(demon_output, 0.0, pad_config)
            demon_targets = demon.get_target(
                transitions, remove_feature_axis(next_values)
            )
            demon_loss = demon.loss(
                demon_output,
                demon_aux,
                add_feature_axis(demon_targets),
                transitions=transitions,
            )
            total = total + demon_loss

        return total
class PredecessorHead(nn.Module):
    """Head with two parallel dense projections of the same input."""

    # Output width of both projections.
    features: int
    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()

    @nn.compact
    def __call__(self, x: Array, **kwargs) -> tuple[tuple[Array, Array], dict]:
        """Return ``((phi, psi_back), {})``, two independent projections of ``x``.

        Extra keyword arguments are ignored.
        """
        # Layers are created in the original order (phi, then psi_back) so the
        # auto-generated parameter names are unchanged.
        phi_layer = nn.Dense(
            features=self.features,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )
        phi = phi_layer(x)
        psi_back_layer = nn.Dense(
            features=self.features,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )
        psi_back = psi_back_layer(x)
        return (phi, psi_back), {}