Source code for memorax.networks.heads

from typing import Callable, Optional

import distrax
import flax.linen as nn
import jax.numpy as jnp
from flax.linen.initializers import constant


class DiscreteQNetwork(nn.Module):
    """Linear head that emits one Q-value per discrete action.

    Attributes are declared as Flax dataclass fields:
      * ``action_dim``  -- number of output features of the Dense layer.
      * ``kernel_init`` -- initializer for the Dense kernel.
      * ``bias_init``   -- initializer for the Dense bias.
    """

    action_dim: int
    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()

    @nn.compact
    def __call__(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
        """Project ``x`` through a single Dense layer of width ``action_dim``.

        Extra keyword arguments are accepted and ignored.
        """
        q_head = nn.Dense(
            features=self.action_dim,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )
        return q_head(x)
class ContinuousQNetwork(nn.Module):
    """Linear head that scores a (feature, action) pair with a single Q-value.

    Attributes are declared as Flax dataclass fields:
      * ``kernel_init`` -- initializer for the Dense kernel.
      * ``bias_init``   -- initializer for the Dense bias.
    """

    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()

    @nn.compact
    def __call__(
        self, x: jnp.ndarray, *, action: jnp.ndarray, **kwargs
    ) -> jnp.ndarray:
        """Return the Q-value for ``x`` joined with ``action``.

        ``x`` and ``action`` are concatenated along their last axis, passed
        through a one-unit Dense layer, and the resulting trailing axis of
        size 1 is removed. Extra keyword arguments are accepted and ignored.
        """
        q_head = nn.Dense(
            features=1,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )
        state_action = jnp.concatenate([x, action], axis=-1)
        q_column = q_head(state_action)
        return jnp.squeeze(q_column, -1)
class VNetwork(nn.Module):
    """Linear head that emits a single state-value estimate.

    Attributes are declared as Flax dataclass fields:
      * ``kernel_init`` -- initializer for the Dense kernel.
      * ``bias_init``   -- initializer for the Dense bias.
    """

    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()

    @nn.compact
    def __call__(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
        """Project ``x`` through a one-unit Dense layer.

        The output of the Dense layer is returned as-is, i.e. the trailing
        axis of size 1 is kept (it is not squeezed). Extra keyword arguments
        are accepted and ignored.
        """
        value_head = nn.Dense(
            features=1,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )
        return value_head(x)
class Categorical(nn.Module):
    """Policy head that yields a categorical distribution over actions.

    Attributes are declared as Flax dataclass fields:
      * ``action_dim``  -- number of logits produced by the Dense layer.
      * ``kernel_init`` -- initializer for the Dense kernel.
      * ``bias_init``   -- initializer for the Dense bias.
    """

    action_dim: int
    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()

    @nn.compact
    def __call__(self, x: jnp.ndarray, **kwargs) -> distrax.Categorical:
        """Build a ``distrax.Categorical`` whose logits are a Dense map of ``x``.

        Extra keyword arguments are accepted and ignored.
        """
        logit_layer = nn.Dense(
            features=self.action_dim,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )
        action_logits = logit_layer(x)
        return distrax.Categorical(logits=action_logits)
class Gaussian(nn.Module):
    """Policy head that yields a (optionally transformed) diagonal Gaussian.

    The mean is a Dense map of the input; the log standard deviation is a
    free parameter named ``"log_std"`` that does not depend on the input.

    Attributes are declared as Flax dataclass fields:
      * ``action_dim``  -- width of the mean layer and of ``log_std``.
      * ``transform``   -- callable or ``distrax.Bijector`` applied to samples
                           of the base Gaussian (identity by default).
      * ``kernel_init`` -- initializer for the Dense kernel.
      * ``bias_init``   -- initializer for the Dense bias.
    """

    action_dim: int
    transform: Callable | distrax.Bijector = lambda x: x
    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()

    @nn.compact
    def __call__(self, x: jnp.ndarray, **kwargs):
        """Return ``distrax.Transformed(MultivariateNormalDiag, transform)``.

        Extra keyword arguments are accepted and ignored.
        """
        mean = nn.Dense(
            self.action_dim, kernel_init=self.kernel_init, bias_init=self.bias_init
        )(x)
        # State-independent log-std, initialised to zero (std == 1).
        log_std = self.param("log_std", nn.initializers.zeros, self.action_dim)
        std = jnp.exp(log_std)
        dist = distrax.MultivariateNormalDiag(loc=mean, scale_diag=std)

        # A plain callable (including the identity default) is converted by
        # distrax into a *scalar* bijector, which `distrax.Transformed`
        # rejects because the base distribution has a rank-1 event. Promote
        # scalar bijectors to act on the whole action vector; bijectors that
        # already operate on rank-1 events are passed through untouched.
        bijector = distrax.as_bijector(self.transform)
        if bijector.event_ndims_in == 0:
            bijector = distrax.Block(bijector, ndims=1)
        return distrax.Transformed(dist, bijector)
class SquashedGaussian(nn.Module):
    """Policy head that yields a tanh-squashed diagonal Gaussian.

    Both the mean and the log standard deviation are Dense maps of the input.

    Attributes are declared as Flax dataclass fields:
      * ``action_dim``  -- width of the mean and log-std layers.
      * ``kernel_init`` -- initializer for both Dense kernels.
      * ``bias_init``   -- initializer for both Dense biases.
    """

    action_dim: int
    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()

    # Clipping bounds applied to the predicted log standard deviation.
    # Deliberately un-annotated so they stay plain class attributes rather
    # than dataclass fields.
    LOG_STD_MIN = -10
    LOG_STD_MAX = 2

    @nn.compact
    def __call__(self, x: jnp.ndarray, **kwargs):
        """Return a ``distrax.Transformed`` Gaussian squashed through tanh.

        Recognised keyword arguments:
          * ``temperature`` (default ``1.0``) -- multiplies the standard
            deviation of the base Gaussian.
        Any other keyword arguments are ignored.
        """
        temperature = kwargs.get("temperature", 1.0)

        layer_config = dict(
            features=self.action_dim,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )
        # The mean layer is created first, then the log-std layer, so the
        # automatically assigned submodule names keep their order.
        loc = nn.Dense(**layer_config)(x)
        raw_log_std = nn.Dense(**layer_config)(x)

        bounded_log_std = jnp.clip(
            raw_log_std, self.LOG_STD_MIN, self.LOG_STD_MAX
        )
        scale = jnp.exp(bounded_log_std)

        base = distrax.MultivariateNormalDiag(
            loc=loc, scale_diag=scale * temperature
        )
        # Tanh is a scalar bijector; Block lifts it to the rank-1 event.
        squash = distrax.Block(distrax.Tanh(), ndims=1)
        return distrax.Transformed(base, squash)
class Alpha(nn.Module):
    """Holds a scalar parameter initialised to ``log(initial_alpha)``.

    Attributes are declared as Flax dataclass fields:
      * ``initial_alpha`` -- value whose natural log seeds the parameter.
    """

    initial_alpha: float

    @nn.compact
    def __call__(self) -> jnp.ndarray:
        """Return the stored log-value (it is not exponentiated here).

        The parameter is registered under the name ``"log_temp"`` with an
        empty shape ``()``.
        """
        init_fn = constant(jnp.log(self.initial_alpha))
        return self.param("log_temp", init_fn, ())
class Beta(nn.Module):
    """Holds a scalar parameter initialised to ``log(initial_beta)``.

    Attributes are declared as Flax dataclass fields:
      * ``initial_beta`` -- value whose natural log seeds the parameter.
    """

    initial_beta: float

    @nn.compact
    def __call__(self) -> jnp.ndarray:
        """Return the stored log-value (it is not exponentiated here).

        The parameter is registered under the name ``"log_temp"`` with an
        empty shape ``()``.
        """
        init_fn = constant(jnp.log(self.initial_beta))
        return self.param("log_temp", init_fn, ())