Source code for memorax.networks.layers.causal_conv1d
import jax.numpy as jnp
from flax import linen as nn
from memorax.networks.initializers import bounded_uniform, kaiming_uniform
from memorax.utils.typing import Array
[docs]
class CausalConv1d(nn.Module):
features: int
kernel_size: int = 4
use_bias: bool = True
param_dtype: jnp.dtype = jnp.float32
[docs]
@nn.compact
def __call__(self, x: Array, state: Array) -> tuple:
kernel = self.param(
"kernel",
kaiming_uniform(),
(self.kernel_size, self.features),
self.param_dtype,
)
conv_state = jnp.concatenate([state[:, 1:, :], x], axis=1)
y = jnp.einsum("bkf,kf->bf", conv_state, kernel)[:, None, :]
if self.use_bias:
bias = self.param(
"bias", nn.initializers.zeros_init(), (self.features,), self.param_dtype
)
y = y + bias
return conv_state, y
[docs]
class ParallelCausalConv1d(nn.Module):
features: int
kernel_size: int = 4
use_bias: bool = True
param_dtype: jnp.dtype = jnp.float32
dtype: jnp.dtype = jnp.float32
[docs]
@nn.compact
def __call__(self, x: Array):
*_, feature_group_count = x.shape
padding = self.kernel_size - 1
x = nn.Conv(
features=self.features,
kernel_size=self.kernel_size,
kernel_init=kaiming_uniform(),
bias_init=bounded_uniform(
min_val=-1.0 / jnp.sqrt(self.kernel_size),
max_val=1.0 / jnp.sqrt(self.kernel_size),
),
feature_group_count=feature_group_count,
padding=[(padding, 0)],
use_bias=self.use_bias,
dtype=self.dtype,
param_dtype=self.param_dtype,
)(x)
return x