Source code for memorax.networks.vit

import flax.linen as nn
import jax.numpy as jnp

from memorax.networks.blocks import FFN


class PatchEmbedding(nn.Module):
    """Converts images to patch sequences via Conv2D."""

    patch_size: int = 16
    features: int = 768

    @nn.compact
    def __call__(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
        x = nn.Conv(
            self.features,
            kernel_size=(self.patch_size, self.patch_size),
            strides=(self.patch_size, self.patch_size),
        )(x)
        return x.reshape(x.shape[0], -1, self.features)


[docs] class ViT(nn.Module): """Vision Transformer feature extractor.""" patch_size: int = 16 features: int = 768 num_layers: int = 12 num_heads: int = 12 expansion_factor: int = 4
[docs] @nn.compact def __call__(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray: x = PatchEmbedding(self.patch_size, self.features)(x) positional_embeddin = self.param( nn.initializers.normal(0.02), (1, x.shape[1], self.features) ) x = x + positional_embeddin for _ in range(self.num_layers): skip = x x = nn.LayerNorm()(x) x = nn.MultiHeadDotProductAttention(num_heads=self.num_heads)(x, x) x = skip + x skip = x x = nn.LayerNorm()(x) _, x = FFN( features=self.features, expansion_factor=int(self.expansion_factor) )(x) x = skip + x x = nn.LayerNorm()(x) x = x.mean(axis=1) return x