from typing import Callable, Optional
import distrax
import flax.linen as nn
import jax.numpy as jnp
from flax.linen.initializers import constant
[docs]
class DiscreteQNetwork(nn.Module):
    """Output head that emits one Q-value per discrete action.

    Attributes:
        action_dim: Number of output units of the dense layer.
        kernel_init: Initializer used for the dense layer's kernel.
        bias_init: Initializer used for the dense layer's bias.
    """

    action_dim: int
    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()

    @nn.compact
    def __call__(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
        """Project ``x`` onto ``action_dim`` Q-values.

        Args:
            x: Input features; the dense layer acts on the last axis.
            **kwargs: Accepted and ignored.

        Returns:
            The dense layer's output, with ``action_dim`` entries on the
            last axis.
        """
        head = nn.Dense(
            features=self.action_dim,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )
        return head(x)
[docs]
class ContinuousQNetwork(nn.Module):
    """Output head that scores an (input, action) pair with a single Q-value.

    Attributes:
        kernel_init: Initializer used for the dense layer's kernel.
        bias_init: Initializer used for the dense layer's bias.
    """

    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()

    @nn.compact
    def __call__(self, x: jnp.ndarray, *, action: jnp.ndarray, **kwargs) -> jnp.ndarray:
        """Compute a Q-value for ``x`` paired with ``action``.

        Args:
            x: Input features.
            action: Action array (keyword-only). It is concatenated to ``x``
                along the last axis, so all other axes must match.
            **kwargs: Accepted and ignored.

        Returns:
            The single-unit dense output with its trailing size-1 axis
            removed.
        """
        joint_input = jnp.concatenate([x, action], axis=-1)
        head = nn.Dense(
            features=1,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )
        q_value = head(joint_input)
        return jnp.squeeze(q_value, axis=-1)
[docs]
class VNetwork(nn.Module):
    """Output head that maps the input to a single value estimate.

    Attributes:
        kernel_init: Initializer used for the dense layer's kernel.
        bias_init: Initializer used for the dense layer's bias.
    """

    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()

    @nn.compact
    def __call__(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
        """Apply a single-unit dense layer to ``x``.

        Args:
            x: Input features; the dense layer acts on the last axis.
            **kwargs: Accepted and ignored.

        Returns:
            The dense output. The trailing axis of size 1 is kept (no
            squeeze is applied here).
        """
        head = nn.Dense(
            features=1,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )
        return head(x)
[docs]
class Categorical(nn.Module):
    """Output head that parameterizes a categorical distribution by logits.

    Attributes:
        action_dim: Number of logits produced by the dense layer.
        kernel_init: Initializer used for the dense layer's kernel.
        bias_init: Initializer used for the dense layer's bias.
    """

    action_dim: int
    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()

    @nn.compact
    def __call__(self, x: jnp.ndarray, **kwargs) -> distrax.Categorical:
        """Build a categorical distribution from ``x``.

        Args:
            x: Input features; the dense layer acts on the last axis.
            **kwargs: Accepted and ignored.

        Returns:
            A ``distrax.Categorical`` whose logits are the dense layer's
            output.
        """
        logits_layer = nn.Dense(
            features=self.action_dim,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )
        return distrax.Categorical(logits=logits_layer(x))
[docs]
class Gaussian(nn.Module):
    """Diagonal-Gaussian output head with an input-independent log-std.

    The mean comes from a dense layer applied to the input; the log standard
    deviation is a free parameter named ``"log_std"`` (initialized to zeros)
    that does not depend on the input. The resulting distribution is passed
    through ``transform``.

    Attributes:
        action_dim: Size of the distribution's event (last) axis.
        transform: Bijector, or plain callable, applied to the samples.
            Defaults to the identity.
        kernel_init: Initializer used for the mean layer's kernel.
        bias_init: Initializer used for the mean layer's bias.
    """

    action_dim: int
    transform: Callable | distrax.Bijector = lambda x: x
    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()

    @nn.compact
    def __call__(self, x: jnp.ndarray, **kwargs) -> distrax.Transformed:
        """Build the transformed Gaussian for ``x``.

        Args:
            x: Input features; the dense layer acts on the last axis.
            **kwargs: Accepted and ignored.

        Returns:
            A ``distrax.Transformed`` wrapping a ``MultivariateNormalDiag``
            with the configured ``transform`` applied.
        """
        mean = nn.Dense(
            self.action_dim, kernel_init=self.kernel_init, bias_init=self.bias_init
        )(x)

        # Pass the shape as an explicit tuple; the resulting parameter shape
        # is the same as with the bare int, but this is the form initializers
        # are documented to take.
        log_std = self.param("log_std", nn.initializers.zeros, (self.action_dim,))
        std = jnp.exp(log_std)
        dist = distrax.MultivariateNormalDiag(loc=mean, scale_diag=std)

        # A plain callable (including the default identity lambda) or an
        # element-wise bijector has event_ndims_in == 0, whereas the base
        # distribution has a rank-1 event. Lift such bijectors over the event
        # axis so the ranks agree; bijectors that already operate on rank-1
        # events are passed through untouched.
        bijector = distrax.as_bijector(self.transform)
        if bijector.event_ndims_in == 0:
            bijector = distrax.Block(bijector, ndims=1)

        return distrax.Transformed(dist, bijector)
[docs]
class SquashedGaussian(nn.Module):
    """Tanh-squashed diagonal-Gaussian output head.

    Two dense layers, both applied to the same input, produce the mean and
    the log standard deviation. The log-std is clipped to
    ``[LOG_STD_MIN, LOG_STD_MAX]`` before being exponentiated.

    Attributes:
        action_dim: Size of the distribution's event (last) axis.
        kernel_init: Initializer used for both dense layers' kernels.
        bias_init: Initializer used for both dense layers' biases.
    """

    action_dim: int
    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()

    # Clipping bounds applied to the predicted log standard deviation.
    LOG_STD_MIN = -10
    LOG_STD_MAX = 2

    @nn.compact
    def __call__(self, x: jnp.ndarray, **kwargs):
        """Build the squashed Gaussian for ``x``.

        Args:
            x: Input features; the dense layers act on the last axis.
            **kwargs: Only ``temperature`` is read (default ``1.0``); it
                multiplies the standard deviation. Other keys are ignored.

        Returns:
            A ``distrax.Transformed`` of a ``MultivariateNormalDiag`` under a
            ``Tanh`` bijector blocked over the last axis.
        """
        temperature = kwargs.get("temperature", 1.0)

        def new_head() -> nn.Dense:
            # Both heads share the same configuration; the mean head must be
            # created first so the auto-generated submodule names are stable.
            return nn.Dense(
                features=self.action_dim,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init,
            )

        mean = new_head()(x)
        raw_log_std = new_head()(x)

        bounded_log_std = jnp.clip(raw_log_std, self.LOG_STD_MIN, self.LOG_STD_MAX)
        scale = jnp.exp(bounded_log_std) * temperature

        base = distrax.MultivariateNormalDiag(loc=mean, scale_diag=scale)
        squash = distrax.Block(distrax.Tanh(), ndims=1)
        return distrax.Transformed(base, squash)
[docs]
class Alpha(nn.Module):
    """Holds one learnable scalar kept in log space.

    The parameter is registered under the name ``"log_temp"`` with shape
    ``()`` and starts at ``log(initial_alpha)``.

    Attributes:
        initial_alpha: Starting value whose logarithm seeds the parameter.
    """

    initial_alpha: float

    @nn.compact
    def __call__(self) -> jnp.ndarray:
        """Return the log-space parameter.

        No exponentiation is applied here; the returned value is the
        logarithm, not alpha itself.
        """
        init_fn = constant(jnp.log(self.initial_alpha))
        return self.param("log_temp", init_fn, ())
class Beta(nn.Module):
    """Holds one learnable scalar kept in log space.

    The parameter is registered under the name ``"log_temp"`` with shape
    ``()`` and starts at ``log(initial_beta)``.

    Attributes:
        initial_beta: Starting value whose logarithm seeds the parameter.
    """

    initial_beta: float

    @nn.compact
    def __call__(self) -> jnp.ndarray:
        """Return the log-space parameter.

        No exponentiation is applied here; the returned value is the
        logarithm, not beta itself.
        """
        init_fn = constant(jnp.log(self.initial_beta))
        return self.param("log_temp", init_fn, ())