Source code for memorax.networks.positional_embeddings.alibi
from typing import Any
import jax.numpy as jnp
from flax import struct
from memorax.utils.typing import Array
from .base import RelativePositionalEmbedding
[docs]
@struct.dataclass
class ALiBi(RelativePositionalEmbedding):
    """Relative positional embedding that yields a linear attention bias.

    Each head gets a fixed slope taken from a geometric sequence; the bias
    for a (query, key) pair is that slope multiplied by the signed offset
    ``key_pos - query_pos``. Queries and keys themselves are not modified.
    """

    # Number of attention heads; one slope is produced per head.
    num_heads: int

    def compute_coefficients(self) -> Array:
        """Return the per-head slopes as a 1-D array of length ``num_heads``.

        The slopes are ``r ** 1, r ** 2, ..., r ** num_heads`` with
        ``r = 2 ** (-8 / num_heads)``.

        Raises:
            ValueError: If ``num_heads`` is not a positive power of two.
        """
        num_heads = self.num_heads
        # Explicit check instead of ``assert`` so validation survives
        # ``python -O``. The ``<= 0`` guard is needed because the bit trick
        # alone accepts 0, which would then divide by zero below.
        if num_heads <= 0 or num_heads & (num_heads - 1) != 0:
            raise ValueError("num_heads must be a power of 2")
        ratio = 2 ** (-8 / num_heads)
        return ratio ** jnp.arange(1, num_heads + 1)

    def apply(self, query_pos: Array, key_pos: Array) -> Array:
        """Compute the additive bias for every head and (query, key) pair.

        Args:
            query_pos: Query positions, indexed here as a 2-D array
                ``(B, Q)`` (leading axis presumably batch; confirm with
                callers).
            key_pos: Key positions, indexed here as a 2-D array ``(B, K)``.

        Returns:
            Array of shape ``(B, num_heads, Q, K)`` holding
            ``slope[h] * (key_pos[b, k] - query_pos[b, q])``.
        """
        slopes = self.compute_coefficients()
        # Signed offset of each key relative to each query: (B, Q, K).
        relative_pos = key_pos[:, None, :] - query_pos[:, :, None]
        # Insert a head axis on both operands so they broadcast to
        # (B, num_heads, Q, K).
        return slopes[None, :, None, None] * relative_pos[:, None, :, :]

    def __call__(
        self, query: Array, key: Array, query_pos: Array, key_pos: Array
    ) -> tuple[Array, Array, Any]:
        """Return ``query`` and ``key`` unchanged together with the bias.

        Args:
            query: Query array, passed through untouched.
            key: Key array, passed through untouched.
            query_pos: Query positions forwarded to :meth:`apply`.
            key_pos: Key positions forwarded to :meth:`apply`.

        Returns:
            ``(query, key, bias)`` where ``bias`` is the result of
            ``self.apply(query_pos, key_pos)``.
        """
        bias = self.apply(query_pos, key_pos)
        return query, key, bias