Source code for memorax.networks.positional_embeddings.wpe
import jax
import jax.numpy as jnp
from flax import linen as nn
def _check_positions(positions, num_embeddings):
if not jnp.all(positions < num_embeddings):
raise ValueError(
f"Position indices exceed num_embeddings ({num_embeddings}). "
f"Max position: {jnp.max(positions)}. "
"Ensure num_embeddings >= context_length or that episode resets occur frequently enough."
)
from memorax.utils.typing import Array, Carry, Key
from .base import AbsolutePositionalEmbedding
[docs]
class LearnablePositionalEmbedding(AbsolutePositionalEmbedding, nn.Module):
    """Learned absolute positional embedding with a position counter that resets on ``done``.

    A per-sequence position counter is scanned along the second axis of
    ``done``. At each step whose ``done`` flag is set, the counter restarts at 0.
    The position for that step is looked up in a learned ``nn.Embed`` table and
    added to ``inputs``. The counter after the last step is returned as the
    carry, so positions continue across consecutive calls.

    Attributes:
        num_embeddings: Number of rows in the learned table. Every emitted
            position must be strictly smaller than this (checked at runtime).
        features: Width of each embedding vector. The embedding is added to
            ``inputs``, so this must broadcast against ``inputs``' last axis.
    """

    num_embeddings: int
    features: int

    @nn.nowrap
    def initialize_carry(self, key: Key, input_shape: tuple) -> Carry:
        """Return a zero position counter for each batch element.

        Args:
            key: Not used by this implementation. Presumably kept so the
                signature matches the base-class interface (TODO confirm).
            input_shape: Shape ``(*batch_dims, last_dim)``. The last axis is
                dropped.

        Returns:
            An ``int32`` array of zeros with shape ``batch_dims``.
        """
        *batch_dims, _ = input_shape
        return jnp.zeros(batch_dims, dtype=jnp.int32)

    @nn.compact
    def __call__(
        self,
        inputs: Array,
        done: Array,
        initial_carry: Carry | None = None,
        **kwargs,
    ) -> tuple[Carry, Array]:
        """Add learned positional embeddings to ``inputs``.

        Args:
            inputs: Array whose first axis is the batch. It is added to
                embeddings of shape ``done.shape + (features,)``. This
                presumably means ``(batch, time, features)`` — confirm
                against callers.
            done: Reset flags, vmapped over axis 0 and scanned over axis 1.
                Where a flag is truthy, that step's position is 0.
            initial_carry: Starting position for each batch element. If
                ``None``, it comes from ``initialize_carry`` (all zeros).
            **kwargs: Accepted and ignored.

        Returns:
            ``(carry, outputs)``. ``carry`` is the position the next step of
            each sequence would receive (shape ``(batch,)``). ``outputs`` is
            ``inputs`` plus the positional embeddings.

        Raises:
            ValueError: Raised from a ``jax.debug.callback`` when a position is
                ``>= num_embeddings``. How the error surfaces to the caller
                (directly or wrapped in a JAX runtime error) may depend on the
                JAX version and execution mode — NOTE(review): verify.
        """
        batch_size = inputs.shape[0]
        if initial_carry is None:
            initial_carry = self.initialize_carry(None, (batch_size, inputs.shape[-1]))

        def step(position: Array, done: Array) -> tuple[Array, Array]:
            # Reset before emitting, so a step flagged `done` gets position 0
            # and the following step gets 1.
            position = jnp.where(done, 0, position)
            return position + 1, position

        def compute_positions(done: Array, offset: Array) -> tuple[Array, Array]:
            # Scan one sequence's flags, starting from its carried offset.
            carry, positions = jax.lax.scan(step, offset, done)
            return positions, carry

        positions, carry = jax.vmap(compute_positions)(done, initial_carry)
        # Positions are traced values, so the bounds check runs on the host.
        # Rows beyond the table would otherwise be looked up silently.
        jax.debug.callback(_check_positions, positions, self.num_embeddings)
        position_embeddings = nn.Embed(
            num_embeddings=self.num_embeddings, features=self.features
        )(positions)
        # Return order is (carry, outputs); the annotation reflects this.
        return carry, inputs + position_embeddings