Source code for memorax.networks.positional_embeddings.alibi

from typing import Any

import jax.numpy as jnp
from flax import struct

from memorax.utils.typing import Array

from .base import RelativePositionalEmbedding


@struct.dataclass
class ALiBi(RelativePositionalEmbedding):
    """Relative positional embedding that adds a per-head linear attention bias.

    Each head gets a fixed slope from a geometric sequence; the bias for a
    (query, key) pair is that slope multiplied by ``key_pos - query_pos``.
    Queries and keys themselves are passed through unmodified.
    """

    # Number of attention heads; must be a positive power of two.
    num_heads: int

    def compute_coefficients(self) -> Array:
        """Return the per-head slopes as a 1-D array of length ``num_heads``.

        The slopes are ``r ** 1, r ** 2, ..., r ** num_heads`` with
        ``r = 2 ** (-8 / num_heads)``.

        Raises:
            AssertionError: If ``num_heads`` is not a positive power of two.
        """
        heads = self.num_heads
        # The bit trick `n & (n - 1) == 0` is also true for n == 0, which would
        # otherwise fall through to a ZeroDivisionError below, so positivity
        # is checked explicitly.
        is_power_of_two = heads > 0 and heads & (heads - 1) == 0
        if not is_power_of_two:
            # Raised explicitly (rather than via an `assert` statement) so the
            # validation is not stripped when Python runs with -O; the
            # exception type and message are unchanged for existing callers.
            raise AssertionError("num_heads must be a power of 2")

        ratio = 2 ** (-8 / heads)
        return ratio ** jnp.arange(1, heads + 1)

    def apply(self, query_pos: Array, key_pos: Array) -> Array:
        """Compute the additive attention bias from query and key positions.

        Args:
            query_pos: Query positions. The indexing below assumes a 2-D
                array, presumably ``(batch, query_len)`` -- TODO confirm
                against callers.
            key_pos: Key positions, presumably ``(batch, key_len)`` --
                TODO confirm against callers.

        Returns:
            ``slope * (key_pos - query_pos)`` for every head. For 2-D inputs
            as described above the result has four axes:
            ``(batch, num_heads, query_len, key_len)``.
        """
        slopes = self.compute_coefficients()
        # Broadcast keys along the last axis and queries along the middle axis
        # so every (query, key) pair gets its signed distance.
        relative_pos = key_pos[:, None, :] - query_pos[:, :, None]
        # Insert a head axis into the distances and scale by each head's slope.
        return slopes[None, :, None, None] * relative_pos[:, None, :, :]

    def __call__(
        self, query: Array, key: Array, query_pos: Array, key_pos: Array
    ) -> tuple[Array, Array, Any]:
        """Return ``query`` and ``key`` unchanged together with the bias.

        Args:
            query: Query array; returned as-is.
            key: Key array; returned as-is.
            query_pos: Query positions forwarded to :meth:`apply`.
            key_pos: Key positions forwarded to :meth:`apply`.

        Returns:
            A ``(query, key, bias)`` tuple where ``bias`` is the result of
            ``self.apply(query_pos, key_pos)``.
        """
        bias = self.apply(query_pos, key_pos)
        return query, key, bias