Source code for memorax.networks.vit
import flax.linen as nn
import jax.numpy as jnp
from memorax.networks.blocks import FFN
from memorax.networks.identity import Identity
from memorax.utils.typing import Array
[docs]
class PatchEmbedding(nn.Module):
    """Turn image-like inputs into a sequence of patch embeddings.

    A 2-D convolution whose kernel size and stride both equal
    ``patch_size`` (so the windows do not overlap) emits one
    ``features``-channel vector per window; everything between the leading
    axis and the channel axis is then flattened into one sequence axis.

    Attributes:
        patch_size: Side length used for both the kernel and the stride.
        features: Number of output channels of the convolution.
    """

    patch_size: int = 16
    features: int = 768

    @nn.compact
    def __call__(self, x: Array, **kwargs) -> Array:
        """Embed ``x`` and flatten the result into a token sequence.

        Args:
            x: Input array handed to ``nn.Conv``; presumably
                ``(N, H, W, C)`` images -- confirm against callers.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The convolution output reshaped to
            ``(x.shape[0], -1, features)``.
        """
        # Kernel and stride share one window so each patch is visited once.
        window = (self.patch_size, self.patch_size)
        projection = nn.Conv(
            features=self.features,
            kernel_size=window,
            strides=window,
        )
        patches = projection(x)

        leading = patches.shape[0]
        return patches.reshape(leading, -1, self.features)
[docs]
class ViT(nn.Module):
    """Vision Transformer feature extractor.

    Two usage modes are supported:

    - Token mode (default): inputs are already tokenized, e.g.
      ``(B, T, num_tokens, token_dim)``, and ``patch_embedding`` is left
      as ``Identity()``.
    - Image mode: pass ``patch_embedding=PatchEmbedding(patch_size,
      features)`` so that images are converted to tokens first.

    Input shape: ``(B, T, ...)`` where ``T`` is the time/sequence axis.
    Output shape: ``(B, T, features)``.

    Attributes:
        features: Width of the token representation.
        num_layers: Number of attention + FFN blocks.
        num_heads: Number of attention heads per block.
        expansion_factor: Forwarded (cast to ``int``) to ``FFN``.
        patch_embedding: Module applied to the flattened input before the
            dense projection.
    """

    features: int = 768
    num_layers: int = 12
    num_heads: int = 12
    expansion_factor: int = 4
    patch_embedding: nn.Module = Identity()

    @nn.compact
    def __call__(self, x: Array, **kwargs) -> Array:
        """Encode every ``(batch, time)`` entry of ``x`` into one vector.

        Args:
            x: Array whose first two axes are batch and time; the
                remaining axes are passed on to ``patch_embedding``.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Array reshaped to ``(batch_size, sequence_length, -1)`` after
            averaging over the token axis.
        """
        batch_size, sequence_length, *per_step_shape = x.shape

        # Fold time into batch so each step is encoded independently.
        tokens = x.reshape(batch_size * sequence_length, *per_step_shape)
        tokens = self.patch_embedding(tokens)
        tokens = nn.Dense(self.features)(tokens)

        # One learned offset per token position, broadcast over the batch.
        num_tokens = tokens.shape[1]
        positional_embedding = self.param(
            "positional_embedding",
            nn.initializers.normal(0.02),
            (1, num_tokens, self.features),
        )
        tokens = tokens + positional_embedding

        # Submodules are created in a fixed order inside the loop; Flax
        # derives their parameter names from that order.
        for _layer in range(self.num_layers):
            # Self-attention sub-block: normalize first, add residual after.
            normed = nn.LayerNorm()(tokens)
            attention = nn.MultiHeadDotProductAttention(num_heads=self.num_heads)
            attended = attention(normed, normed)
            tokens = tokens + attended

            # Feed-forward sub-block; FFN returns a pair and only the
            # second element is used here.
            normed = nn.LayerNorm()(tokens)
            feed_forward = FFN(
                features=self.features,
                expansion_factor=int(self.expansion_factor),
            )
            _, transformed = feed_forward(normed)
            tokens = tokens + transformed

        tokens = nn.LayerNorm()(tokens)

        # Average over the token axis, then restore the (batch, time) axes.
        pooled = tokens.mean(axis=1)
        return pooled.reshape(batch_size, sequence_length, -1)