Source code for memorax.networks.layers.flatten

from math import prod

import jax.numpy as jnp
from flax import linen as nn

from memorax.utils.typing import Array


[docs] class Flatten(nn.Module): start_dim: int = 1 end_dim: int = -1
[docs] @nn.compact def __call__(self, x: Array) -> Array: shape = x.shape ndim = len(shape) start = self.start_dim if self.start_dim >= 0 else ndim + self.start_dim end = self.end_dim if self.end_dim >= 0 else ndim + self.end_dim new_shape = shape[:start] + (prod(shape[start : end + 1]),) + shape[end + 1 :] return x.reshape(new_shape)