Source code for memorax.networks.cnn

from typing import Callable, Optional, Sequence, Union

import flax.linen as nn
import jax.numpy as jnp


class CNN(nn.Module):
    """Stack of convolution layers whose output is flattened to 3 axes.

    Layer ``i`` applies ``nn.Conv(features[i], kernel_sizes[i], strides[i])``,
    then ``normalizer()`` (if given), then ``activation``, then
    ``poolings[i]`` (if given and not ``None``). After the last layer the
    result is reshaped to ``(shape[0], shape[1], -1)``.

    Attributes:
        features: Number of output channels of each conv layer.
        kernel_sizes: Kernel size of each conv layer.
        strides: Stride of each conv layer (an int or a per-axis tuple).
        poolings: Optional per-layer callables applied after the activation.
            ``None`` (or an empty sequence) disables pooling for every layer;
            a ``None`` entry disables it for that layer only.
        padding: Padding mode forwarded to every ``nn.Conv``.
        activation: Function applied after each conv (and normalizer).
        normalizer: Optional zero-argument factory; a fresh instance is
            created and applied after every conv layer.
        kernel_init: Initializer for the conv kernels.
        bias_init: Initializer for the conv biases.
    """

    features: Sequence[int]
    kernel_sizes: Sequence[tuple[int, int]]
    strides: Sequence[int | tuple[int, int]]
    poolings: Optional[Sequence[Callable]] = None
    padding: str = "VALID"
    activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    normalizer: Optional[Callable] = None
    kernel_init: nn.initializers.Initializer = nn.initializers.lecun_normal()
    bias_init: nn.initializers.Initializer = nn.initializers.zeros_init()

    @nn.compact
    def __call__(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
        """Run the conv stack and flatten everything after the first two axes.

        Args:
            x: Input array. Presumably laid out as
                ``(batch, sequence, height, width, channels)`` so that the
                2-D convs treat the two leading axes as batch axes
                — TODO confirm against callers.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The final feature map reshaped to
            ``(batch_size, sequence_length, -1)``, where the first two sizes
            are the leading two axes of the feature map.

        Raises:
            ValueError: If ``features``, ``kernel_sizes``, ``strides`` and
                (when provided) ``poolings`` do not all have the same length.
        """
        poolings = self.poolings or [None] * len(self.features)

        # zip() would silently drop trailing layers on a length mismatch,
        # so reject inconsistent configurations explicitly.
        lengths = {
            "features": len(self.features),
            "kernel_sizes": len(self.kernel_sizes),
            "strides": len(self.strides),
            "poolings": len(poolings),
        }
        if len(set(lengths.values())) > 1:
            raise ValueError(
                f"CNN per-layer configuration lengths must match, got {lengths}"
            )

        for feature, kernel_size, stride, pooling in zip(
            self.features, self.kernel_sizes, self.strides, poolings
        ):
            x = nn.Conv(
                feature,
                kernel_size=kernel_size,
                strides=stride,
                padding=self.padding,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init,
            )(x)
            if self.normalizer is not None:
                x = self.normalizer()(x)
            x = self.activation(x)
            if pooling is not None:
                x = pooling(x)

        # Keep the two leading axes and flatten the rest into one feature axis.
        batch_size, sequence_length, *_ = x.shape
        x = x.reshape((batch_size, sequence_length, -1))
        return x