Source code for memorax.networks.sequence_models.self_attention

from functools import partial
from typing import Any, Optional

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax import struct

from memorax.networks.positional_embeddings import RelativePositionalEmbedding
from memorax.networks.sequence_models.utils import (
    get_attention_implementation,
    get_attention_mask,
    get_input_shape,
)
from memorax.utils.typing import Array

from .sequence_model import SequenceModel


@struct.dataclass
class Carry:
    """Recurrent state carried between ``SelfAttention`` calls.

    Holds a fixed-length window of past steps. Fields are positional in the
    order ``Carry(mask, key, value)``.
    """

    # Mask for each cached position; SelfAttention.initialize_carry creates it
    # as int32 ones of shape (batch, context_length).
    mask: Array
    # Cached projected keys, shape (batch, context_length, num_heads, head_dim).
    key: Array
    # Cached projected values, same shape as ``key``.
    value: Array


class SelfAttention(SequenceModel):
    """Multi-head self-attention with a rolling key/value cache.

    Projected keys/values from earlier calls are stored in a ``Carry`` of
    length ``context_length``. Each call attends from the new tokens to the
    cached window plus the new tokens (and optional ``memory`` tokens), then
    keeps the most recent ``context_length`` entries as the next carry.

    Attributes:
        features: Model width. It is used as the width of the placeholder
            memory and of the ``"out"`` projection. Each head has
            ``features // num_heads`` dimensions.
        num_heads: Number of attention heads.
        context_length: Number of past positions kept in the carry; also the
            maximum number of new tokens ``T`` accepted per call.
        dtype: Computation dtype for the projections and the cache.
        param_dtype: Parameter dtype for the dense layers.
        kernel_init: Kernel initializer for all dense layers.
        bias_init: Bias initializer for all dense layers.
        positional_embedding: Callable
            ``(query, key, query_pos, key_pos) -> (query, key, bias)``. The
            default returns query/key unchanged and no bias.
    """

    features: int
    num_heads: int
    context_length: int
    dtype: Any
    param_dtype: Any
    kernel_init: Any
    bias_init: Any
    positional_embedding: RelativePositionalEmbedding = (
        lambda query, key, query_pos, key_pos: (query, key, None)
    )

    @nn.nowrap
    def initialize_carry(self, key, input_shape):
        """Create an empty carry for a batch.

        Args:
            key: PRNG key. Not used by this implementation.
            input_shape: Shape of the input. Only the leading (batch) entry
                is read.

        Returns:
            A ``Carry`` with an int32 all-ones mask of shape
            ``(batch, context_length)``, and zero key/value caches of shape
            ``(batch, context_length, num_heads, features // num_heads)``
            in ``self.dtype``.
        """
        batch_size, *_ = input_shape
        head_dim = self.features // self.num_heads
        cache_shape = (batch_size, self.context_length, self.num_heads, head_dim)
        mask = jnp.ones((batch_size, self.context_length), dtype=jnp.int32)
        # Distinct names so the ``key`` PRNG argument is not shadowed.
        key_cache = jnp.zeros(cache_shape, dtype=self.dtype)
        value_cache = jnp.zeros(cache_shape, dtype=self.dtype)
        return Carry(mask, key_cache, value_cache)

    @nn.compact
    def __call__(
        self,
        x,
        mask,
        initial_carry: Optional[Carry] = None,
        memory: Optional[Array] = None,
        memory_mask: Optional[Array] = None,
        **kwargs,
    ):
        """Attend from ``x`` over the cached window, ``x`` itself and ``memory``.

        Args:
            x: Input sequence. The first two axes are read as ``(B, T)``. Its
                last axis must match ``memory``'s, because both are
                concatenated along axis 1 (``features`` when no memory is
                given).
            mask: Step mask. It is concatenated with the carry's
                ``(B, context_length)`` mask along axis 1, so it is presumably
                ``(B, T)``. Its meaning is defined by ``get_attention_mask``.
            initial_carry: Carry from the previous call. A fresh one from
                ``initialize_carry`` is used when ``None``.
            memory: Optional extra tokens ``(B, M, ...)``. They are projected
                with the same key/value layers and placed before the cache.
            memory_mask: Mask for ``memory``, forwarded to
                ``get_attention_mask``. It is replaced by an empty ``(B, 0)``
                mask when ``memory`` is ``None``.
            **kwargs: Ignored.

        Returns:
            ``(carry, y)``. ``carry`` holds the last ``context_length``
            mask/key/value entries. ``y`` is the ``"out"`` projection, whose
            last axis has size ``features``.

        Raises:
            ValueError: If ``T`` exceeds ``context_length``.
        """
        if initial_carry is None:
            input_shape = get_input_shape(x)
            initial_carry = self.initialize_carry(jax.random.key(0), input_shape)

        B, T, *_ = x.shape
        head_dim = self.features // self.num_heads

        if memory is None:
            # Zero-length placeholders keep the concatenations below uniform.
            memory = jnp.zeros((B, 0, self.features), dtype=self.dtype)
            memory_mask = jnp.zeros((B, 0), dtype=jnp.int32)
        _, M, *_ = memory.shape

        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if T > self.context_length:
            raise ValueError(
                f"T must be less than or equal to context_length, but was T: {T}, "
                f"context_length: {self.context_length}"
            )

        projection = partial(
            nn.DenseGeneral,
            features=(self.num_heads, head_dim),
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )

        query = projection(name="query")(x)

        # Memory and new tokens share the key/value projections. The cached
        # entries are spliced in between them: [memory, cache, new].
        kv_input = jnp.concatenate([memory, x], axis=1)

        key = projection(name="key")(kv_input)
        key = jnp.concatenate([key[:, :M], initial_carry.key, key[:, M:]], axis=1)
        # NOTE(review): the concatenation above has length M + context_length + T,
        # so this slice drops its first T positions. When M > 0, those are the
        # first min(T, M) memory tokens (plus T - M cached entries when T > M),
        # not only the oldest cached entries. Confirm this matches the key layout
        # get_attention_mask assumes before changing it.
        key = key[:, -(M + self.context_length) :]

        value = projection(name="value")(kv_input)
        value = jnp.concatenate(
            [value[:, :M], initial_carry.value, value[:, M:]], axis=1
        )
        # NOTE(review): same slicing caveat as for ``key`` above.
        value = value[:, -(M + self.context_length) :]

        attention_mask, query_input, key_input = get_attention_mask(
            mask, initial_carry, memory_mask, self.context_length, self.num_heads
        )
        query, key, bias = self.positional_embedding(
            query, key, query_input, key_input
        )

        implementation, attention_dtype = get_attention_implementation()
        x = jax.nn.dot_product_attention(
            query.astype(attention_dtype),
            key.astype(attention_dtype),
            value.astype(attention_dtype),
            bias=bias,
            mask=attention_mask,
            implementation=implementation,
        ).astype(self.dtype)

        # Merge the (heads, head_dim) axes back into ``features``.
        y = nn.DenseGeneral(
            self.features,
            axis=(-2, -1),
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            name="out",
        )(x)

        # Keep only the most recent ``context_length`` steps for the next call.
        # The trailing ``context_length`` slice of key/value is [cache, new]
        # regardless of M, so the stored cache excludes memory.
        mask = jnp.concatenate([initial_carry.mask, mask], axis=1)[
            :, -self.context_length :
        ]
        key = key[:, -self.context_length :]
        value = value[:, -self.context_length :]
        carry = initial_carry.replace(mask=mask, key=key, value=value)
        return carry, y