Source code for memorax.networks.sequence_models.self_attention

from functools import partial
from typing import Literal

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax import struct
from flax.typing import Dtype, Initializer

from memorax.networks.positional_embeddings import RelativePositionalEmbedding
from memorax.utils.axes import get_input_shape
from memorax.utils.typing import Array, Key

# Names of the attention kernels that ``get_attention_implementation`` can select.
Implementation = Literal["xla", "cudnn"]


def get_attention_implementation() -> tuple[Implementation, jnp.dtype]:
    """Pick the attention kernel and compute dtype for the current JAX backend.

    Returns:
        ``("cudnn", jnp.bfloat16)`` when the default backend is ``"gpu"`` and at
        least one local device reports a device kind containing ``"nvidia"``;
        ``("xla", jnp.float32)`` in every other case, including when probing
        the devices raises.
    """
    fallback = ("xla", jnp.float32)

    if jax.default_backend() != "gpu":
        return fallback

    try:
        has_nvidia_device = any(
            "nvidia" in device.device_kind.lower() for device in jax.local_devices()
        )
    except Exception:
        # Device probing is best-effort: any failure simply selects the
        # portable XLA path instead of propagating.
        return fallback

    return ("cudnn", jnp.bfloat16) if has_nvidia_device else fallback


def get_attention_mask(done, initial_carry, memory_done, context_length, num_heads):
    """Build the attention mask and position indices for windowed attention.

    A query may attend to a key only when both carry the same cumulative count
    of ``done`` flags and the key's position index is not greater than the
    query's.

    Args:
        done: ``(B, T)`` done flags for the current inputs.
        initial_carry: Object whose ``done`` attribute holds the done flags of
            the cached positions; it is concatenated with ``memory_done`` and
            ``done`` along axis 1.
        memory_done: Done flags for the memory positions; its second dimension
            is the memory length ``M``.
        context_length: Number of cached positions kept in the key window.
        num_heads: Size of the head axis the masks are broadcast to.

    Returns:
        A tuple ``(combined_mask, query_input, key_input)`` where
        ``combined_mask`` is a boolean mask broadcast over ``num_heads``,
        ``query_input`` holds the ``(B, T)`` query position indices and
        ``key_input`` holds the position indices of the key window (the last
        ``M + context_length`` positions).
    """
    batch_size, num_queries = done.shape
    _, M, *_ = memory_done.shape
    window = M + context_length
    total = window + num_queries

    # Cumulative done count of each query, offset by the largest count reached
    # over the memory and cached positions.
    history_done = jnp.concatenate([memory_done, initial_carry.done], axis=1)
    history_offset = jnp.max(jnp.cumsum(history_done, axis=1), axis=1)[..., None]
    query_segments = jnp.cumsum(done.astype(jnp.int32), axis=1) + history_offset

    # Cumulative done count of every position, restricted to the key window.
    all_done = jnp.concatenate(
        [memory_done, initial_carry.done, done], axis=1, dtype=jnp.int32
    )
    key_segments = jnp.cumsum(all_done, axis=1)[:, -window:]

    segment_mask = nn.make_attention_mask(
        query_segments, key_segments, pairwise_fn=jnp.equal
    )

    # Position indices: queries occupy the final T slots of the full sequence,
    # keys the final ``window`` slots.
    query_input = jnp.broadcast_to(
        jnp.arange(num_queries) + M + context_length, (batch_size, num_queries)
    )
    key_input = jnp.broadcast_to(jnp.arange(total), (batch_size, total))[:, -window:]
    causal_mask = nn.make_attention_mask(
        query_input, key_input, pairwise_fn=jnp.greater_equal
    )

    def expand_heads(mask):
        # Replace the second axis of the 4-D mask with ``num_heads``.
        rows, _, num_q, num_k = mask.shape
        return jnp.broadcast_to(mask, (rows, num_heads, num_q, num_k))

    combined_mask = nn.combine_masks(
        expand_heads(segment_mask), expand_heads(causal_mask), dtype=jnp.bool
    )
    return combined_mask, query_input, key_input


from .sequence_model import SequenceModel


def _identity_positional_embedding(query, key, query_pos, key_pos):
    """Default positional embedding: return query and key unchanged, with no bias."""
    return query, key, None


@struct.dataclass
class SelfAttentionConfig:
    """Configuration for ``SelfAttention``."""

    # Total feature width; split evenly across heads
    # (head_dim = features // num_heads) and used as the output width.
    features: int
    # Number of attention heads.
    num_heads: int
    # Number of past positions kept in the recurrent carry.
    context_length: int
    # Computation dtype of the projections and of the carry buffers.
    dtype: Dtype = jnp.float32
    # Dtype of the projection parameters.
    param_dtype: Dtype = jnp.float32
    # Initializers shared by the query/key/value/output projections.
    kernel_init: Initializer = struct.field(
        pytree_node=False,
        default=nn.initializers.lecun_normal(),
    )
    bias_init: Initializer = struct.field(
        pytree_node=False,
        default=nn.initializers.zeros_init(),
    )
    # Callable ``(query, key, query_pos, key_pos) -> (query, key, bias)``.
    # The default leaves query and key untouched and supplies no bias.
    positional_embedding: RelativePositionalEmbedding = struct.field(
        pytree_node=False,
        default=_identity_positional_embedding,
    )


@struct.dataclass
class SelfAttentionCarry:
    """Recurrent state threaded between successive ``SelfAttention`` calls.

    ``SelfAttention.initialize_carry`` creates every field with a
    ``context_length`` axis directly after the batch dimensions.
    """

    # Done flags of the cached positions; created as int32 ones of shape
    # (*batch_dims, context_length).
    done: Array
    # Cached keys; created as zeros of shape
    # (*batch_dims, context_length, num_heads, head_dim).
    key: Array
    # Cached values; same shape and dtype as ``key``.
    value: Array


class SelfAttention(SequenceModel):
    """Multi-head self-attention over a sliding window of cached keys/values.

    Each call attends from the ``T`` new inputs to the last
    ``M + context_length`` positions of ``[memory, cached positions, inputs]``
    and returns an updated carry holding the most recent ``context_length``
    positions.
    """

    config: SelfAttentionConfig

    def setup(self):
        """Create the query/key/value projections and the output projection."""
        head_dim = self.config.features // self.config.num_heads
        projection = partial(
            nn.DenseGeneral,
            features=(self.config.num_heads, head_dim),
            dtype=self.config.dtype,
            param_dtype=self.config.param_dtype,
            kernel_init=self.config.kernel_init,
            bias_init=self.config.bias_init,
        )
        self.query = projection()
        self.key = projection()
        self.value = projection()
        # Merges the (num_heads, head_dim) axes back into ``features``.
        self.output_projection = nn.DenseGeneral(
            self.config.features,
            axis=(-2, -1),
            dtype=self.config.dtype,
            param_dtype=self.config.param_dtype,
            kernel_init=self.config.kernel_init,
            bias_init=self.config.bias_init,
        )

    @nn.nowrap
    def initialize_carry(self, key: Key, input_shape: tuple) -> SelfAttentionCarry:
        """Return an empty carry for inputs of shape ``input_shape``.

        Args:
            key: PRNG key; unused, the carry is deterministic.
            input_shape: Input shape whose last entry (the feature size) is
                dropped; the remaining entries are the batch dimensions.

        Returns:
            A carry with all-ones int32 ``done`` flags of shape
            ``(*batch_dims, context_length)`` and zero key/value buffers of
            shape ``(*batch_dims, context_length, num_heads, head_dim)``.
        """
        *batch_dims, _ = input_shape
        head_dim = self.config.features // self.config.num_heads
        cache_shape = (
            *batch_dims,
            self.config.context_length,
            self.config.num_heads,
            head_dim,
        )
        done = jnp.ones((*batch_dims, self.config.context_length), dtype=jnp.int32)
        cached_key = jnp.zeros(cache_shape, dtype=self.config.dtype)
        cached_value = jnp.zeros(cache_shape, dtype=self.config.dtype)
        return SelfAttentionCarry(done, cached_key, cached_value)

    def __call__(
        self,
        x,
        done,
        initial_carry: SelfAttentionCarry | None = None,
        memory: Array | None = None,
        memory_done: Array | None = None,
        **kwargs,
    ) -> tuple[SelfAttentionCarry, Array]:
        """Attend from ``x`` to the cached window and return the new carry.

        Args:
            x: Inputs whose first two axes are batch ``B`` and time ``T``;
                ``T`` must not exceed ``config.context_length``.
            done: ``(B, T)`` done flags aligned with ``x``.
            initial_carry: Carry from the previous call. When ``None`` a fresh
                carry is created with ``initialize_carry``.
            memory: Optional extra positions that are projected to keys/values
                and placed before the cached positions. When ``None`` an empty
                memory is used and ``memory_done`` is ignored.
            memory_done: Done flags for ``memory``. When ``memory`` is given
                without it, all-zero flags are assumed.
            **kwargs: Accepted and ignored.

        Returns:
            ``(carry, y)``: the updated carry and the output-projected
            attention result.
        """
        if initial_carry is None:
            input_shape = get_input_shape(x)
            initial_carry = self.initialize_carry(jax.random.key(0), input_shape)

        B, T, *_ = x.shape

        if memory is None:
            memory = jnp.zeros((B, 0, self.config.features), dtype=self.config.dtype)
            memory_done = jnp.zeros((B, 0), dtype=jnp.int32)
        elif memory_done is None:
            # Memory supplied without done flags: treat it as containing no
            # boundaries rather than failing on ``None.shape`` later on.
            memory_done = jnp.zeros(memory.shape[:2], dtype=jnp.int32)

        _, M, *_ = memory.shape
        context_length = self.config.context_length
        window = M + context_length

        assert (
            T <= self.config.context_length
        ), f"T must be less than or equal to context_length, but was T: {T}, context_length: {self.config.context_length}"

        # Project memory and inputs together, then splice the cached
        # keys/values between the memory part and the new-input part.
        memory_and_inputs = jnp.concatenate([memory, x], axis=1)

        query = self.query(x)

        key = self.key(memory_and_inputs)
        key = jnp.concatenate([key[:, :M], initial_carry.key, key[:, M:]], axis=1)
        key = key[:, -window:]

        value = self.value(memory_and_inputs)
        value = jnp.concatenate(
            [value[:, :M], initial_carry.value, value[:, M:]], axis=1
        )
        value = value[:, -window:]

        attention_mask, query_input, key_input = get_attention_mask(
            done,
            initial_carry,
            memory_done,
            self.config.context_length,
            self.config.num_heads,
        )

        # Keep the un-embedded ``key`` for the carry. Positions are relative
        # to the current window and the embedding is re-applied to the cached
        # keys on every call, so caching embedded keys would compound the
        # embedding across calls.
        query, embedded_key, bias = self.config.positional_embedding(
            query, key, query_input, key_input
        )

        implementation, attention_dtype = get_attention_implementation()
        attended = jax.nn.dot_product_attention(
            query.astype(attention_dtype),
            embedded_key.astype(attention_dtype),
            value.astype(attention_dtype),
            bias=bias,
            mask=attention_mask,
            implementation=implementation,
        ).astype(self.config.dtype)

        y = self.output_projection(attended)

        # Slide the window: keep only the most recent ``context_length`` slots.
        done = jnp.concatenate([initial_carry.done, done], axis=1)[
            :, -context_length:
        ]
        carry = initial_carry.replace(
            done=done,
            key=key[:, -context_length:],
            value=value[:, -context_length:],
        )

        return carry, y