from functools import partial
from typing import Literal
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax import struct
from flax.typing import Dtype, Initializer
from memorax.networks.positional_embeddings import RelativePositionalEmbedding
from memorax.utils.axes import get_input_shape
from memorax.utils.typing import Array, Key
Implementation = Literal["xla", "cudnn"]
def get_attention_implementation() -> tuple[Implementation, jnp.dtype]:
backend = jax.default_backend()
if backend == "gpu":
try:
if any(
"nvidia" in device.device_kind.lower() for device in jax.local_devices()
):
return "cudnn", jnp.bfloat16
except Exception:
pass
return "xla", jnp.float32
def get_attention_mask(done, initial_carry, memory_done, context_length, num_heads):
B, T = done.shape
_, M, *_ = memory_done.shape
query_mask = (
jnp.cumsum(done.astype(jnp.int32), axis=1)
+ jnp.max(
jnp.cumsum(
jnp.concatenate([memory_done, initial_carry.done], axis=1), axis=1
),
axis=1,
)[..., None]
)
key_mask = jnp.concatenate(
[memory_done, initial_carry.done, done], axis=1, dtype=jnp.int32
)
key_mask = jnp.cumsum(key_mask, axis=1)
key_mask = key_mask[:, -(M + context_length) :]
attention_mask = nn.make_attention_mask(query_mask, key_mask, pairwise_fn=jnp.equal)
query_input = jnp.arange(T) + M + context_length
query_input = jnp.broadcast_to(query_input, (B, T))
key_input = jnp.arange(M + context_length + T)
key_input = jnp.broadcast_to(key_input, (B, M + context_length + T))
key_input = key_input[:, -(M + context_length) :]
causal_mask = nn.make_attention_mask(
query_input, key_input, pairwise_fn=jnp.greater_equal
)
B, _, T, S = attention_mask.shape
attention_mask = jnp.broadcast_to(attention_mask, (B, num_heads, T, S))
B, _, T, S = causal_mask.shape
causal_mask = jnp.broadcast_to(causal_mask, (B, num_heads, T, S))
combined_mask = nn.combine_masks(attention_mask, causal_mask, dtype=jnp.bool)
return combined_mask, query_input, key_input
from .sequence_model import SequenceModel
@struct.dataclass
class SelfAttentionConfig:
features: int
num_heads: int
context_length: int
dtype: Dtype = jnp.float32
param_dtype: Dtype = jnp.float32
kernel_init: Initializer = struct.field(pytree_node=False, default=nn.initializers.lecun_normal())
bias_init: Initializer = struct.field(pytree_node=False, default=nn.initializers.zeros_init())
positional_embedding: RelativePositionalEmbedding = struct.field(
pytree_node=False,
default=lambda query, key, query_pos, key_pos: (query, key, None),
)
@struct.dataclass
class SelfAttentionCarry:
done: Array
key: Array
value: Array
[docs]
class SelfAttention(SequenceModel):
config: SelfAttentionConfig
[docs]
def setup(self):
head_dim = self.config.features // self.config.num_heads
projection = partial(
nn.DenseGeneral,
features=(self.config.num_heads, head_dim),
dtype=self.config.dtype,
param_dtype=self.config.param_dtype,
kernel_init=self.config.kernel_init,
bias_init=self.config.bias_init,
)
self.query = projection()
self.key = projection()
self.value = projection()
self.output_projection = nn.DenseGeneral(
self.config.features,
axis=(-2, -1),
dtype=self.config.dtype,
param_dtype=self.config.param_dtype,
kernel_init=self.config.kernel_init,
bias_init=self.config.bias_init,
)
[docs]
@nn.nowrap
def initialize_carry(self, key: Key, input_shape: tuple) -> SelfAttentionCarry:
*batch_dims, _ = input_shape
head_dim = self.config.features // self.config.num_heads
done = jnp.ones((*batch_dims, self.config.context_length), dtype=jnp.int32)
key = jnp.zeros(
(*batch_dims, self.config.context_length, self.config.num_heads, head_dim),
dtype=self.config.dtype,
)
value = jnp.zeros(
(*batch_dims, self.config.context_length, self.config.num_heads, head_dim),
dtype=self.config.dtype,
)
return SelfAttentionCarry(done, key, value)
[docs]
def __call__(
self,
x,
done,
initial_carry: SelfAttentionCarry | None = None,
memory: Array | None = None,
memory_done: Array | None = None,
**kwargs,
) -> tuple[SelfAttentionCarry, Array]:
if initial_carry is None:
input_shape = get_input_shape(x)
initial_carry = self.initialize_carry(jax.random.key(0), input_shape)
B, T, *_ = x.shape
if memory is None:
memory = jnp.zeros((B, 0, self.config.features), dtype=self.config.dtype)
memory_done = jnp.zeros((B, 0), dtype=jnp.int32)
_, M, *_ = memory.shape
assert (
T <= self.config.context_length
), f"T must be less than or equal to context_length, but was T: {T}, context_length: {self.config.context_length}"
query = self.query(x)
key = self.key(jnp.concatenate([memory, x], axis=1))
key = jnp.concatenate([key[:, :M], initial_carry.key, key[:, M:]], axis=1)
key = key[:, -(M + self.config.context_length) :]
value = self.value(jnp.concatenate([memory, x], axis=1))
value = jnp.concatenate(
[value[:, :M], initial_carry.value, value[:, M:]], axis=1
)
value = value[:, -(M + self.config.context_length) :]
attention_mask, query_input, key_input = get_attention_mask(
done, initial_carry, memory_done, self.config.context_length, self.config.num_heads
)
query, key, bias = self.config.positional_embedding(query, key, query_input, key_input)
implementation, attention_dtype = get_attention_implementation()
x = jax.nn.dot_product_attention(
query.astype(attention_dtype),
key.astype(attention_dtype),
value.astype(attention_dtype),
bias=bias,
mask=attention_mask,
implementation=implementation,
).astype(self.config.dtype)
y = self.output_projection(x)
done = jnp.concatenate([initial_carry.done, done], axis=1)[
:, -self.config.context_length :
]
key = key[:, -self.config.context_length :]
value = value[:, -self.config.context_length :]
carry = initial_carry.replace(done=done, key=key, value=value)
return carry, y