Source code for memorax.networks.positional_embeddings.rope
from typing import Any
import jax.numpy as jnp
from flax import struct
from memorax.utils.typing import Array
from .base import RelativePositionalEmbedding
[docs]
@struct.dataclass
class RoPE(RelativePositionalEmbedding):
    """Rotary positional embedding applied to query and key arrays.

    Each feature ``i`` in the first half of the last axis is paired with
    feature ``i + head_dim // 2`` in the second half, and every pair is
    rotated by an angle ``position * base ** (-2 * i / head_dim)``.
    """

    # Base of the geometric progression that defines the rotation frequencies.
    base: float = 10000.0

    def _inverse_frequencies(self, dim: int) -> Array:
        """Return the float32 inverse frequencies ``base ** (-2 * i / dim)``.

        One value is produced per even index in ``range(0, dim, 2)``, so the
        result has shape ``(ceil(dim / 2),)``.
        """
        exponents = jnp.arange(0, dim, 2, dtype=jnp.float32) / dim
        return 1.0 / (self.base**exponents)

    def compute_coefficients(self, dim: int, max_seq_len: int) -> Array:
        """Return the rotation angles for positions ``0 .. max_seq_len - 1``.

        Args:
            dim: Size of the feature axis the angles are meant for.
            max_seq_len: Number of consecutive positions, starting at 0.

        Returns:
            A float32 array of shape ``(max_seq_len, ceil(dim / 2))`` whose
            entry ``[p, i]`` is ``p * base ** (-2 * i / dim)``.
        """
        t = jnp.arange(max_seq_len, dtype=jnp.float32)
        return jnp.outer(t, self._inverse_frequencies(dim))

    def rotate(self, x: Array) -> Array:
        """Swap the two halves of the last axis, negating the second half.

        The last axis is split into ``(x1, x2)`` and ``(-x2, x1)`` is returned.
        ``jnp.split`` with an integer section count requires the last axis to
        have an even size.
        """
        x1, x2 = jnp.split(x, 2, axis=-1)
        return jnp.concatenate([-x2, x1], axis=-1)

    def apply(self, x: Array, positions: Array) -> Array:
        """Rotate ``x`` by position-dependent angles.

        Args:
            x: Array whose last axis is the feature axis to rotate. The axis
                just before it is broadcast over (presumably the heads axis,
                i.e. ``(..., seq, heads, head_dim)`` -- TODO confirm with
                callers).
            positions: Positions for ``x``; cast to float32 before use. Its
                shape must broadcast against the axes of ``x`` that precede
                the last two, e.g. ``(batch, seq)`` for a 4-D ``x``.

        Returns:
            ``x * cos + rotate(x) * sin``. The cos/sin tables are float32, so
            the result dtype follows JAX type promotion with ``x.dtype``.
        """
        head_dim = x.shape[-1]
        t = positions.astype(jnp.float32)
        # One angle per (position, frequency) pair: positions.shape + (head_dim / 2,).
        frequencies = t[..., None] * self._inverse_frequencies(head_dim)
        # Insert a singleton axis in front of the feature axis so the angles
        # broadcast over the second-to-last axis of ``x``. The ellipsis keeps
        # this valid for positions of any rank, not only (batch, seq).
        frequencies = frequencies[..., None, :]
        cos = jnp.cos(frequencies)
        sin = jnp.sin(frequencies)
        # Duplicate the tables so both halves of a feature pair share an angle,
        # matching the split-half layout used by ``rotate``.
        cos = jnp.concatenate([cos, cos], axis=-1)
        sin = jnp.concatenate([sin, sin], axis=-1)
        return x * cos + self.rotate(x) * sin

    def __call__(
        self, query: Array, key: Array, query_pos: Array, key_pos: Array
    ) -> tuple[Array, Array, Any]:
        """Apply the rotation to ``query`` and ``key`` at their own positions.

        Args:
            query: Query array, rotated using ``query_pos``.
            key: Key array, rotated using ``key_pos``.
            query_pos: Positions associated with ``query``.
            key_pos: Positions associated with ``key``.

        Returns:
            ``(rotated_query, rotated_key, None)``; the third element is
            always ``None`` for this embedding.
        """
        query = self.apply(query, query_pos)
        key = self.apply(key, key_pos)
        return query, key, None