Source code for memorax.networks.blocks.router
from typing import NamedTuple
import flax.linen as nn
import jax
import jax.numpy as jnp
from memorax.utils.typing import Array
[docs]
class TopKRouter(nn.Module):
    """Top-K router for Mixture of Experts.

    Projects every token onto ``num_experts`` routing logits with a bias-free
    dense layer, keeps the ``k`` most probable experts per token and
    renormalises the kept probabilities so they sum to (approximately) one.

    As a side effect, an auxiliary scalar loss is sown into the
    ``"intermediates"`` collection under the name ``"moe_loss"``.

    Attributes:
        num_experts: Number of experts to route between.
        k: Number of experts selected for each token.
    """

    num_experts: int
    k: int = 2

    @nn.compact
    def __call__(self, inputs: Array) -> tuple:
        """Select the top-k experts for every token.

        Args:
            inputs: Array whose last axis holds the features. Every leading
                axis is treated as a token axis, so ``(batch, seq, features)``
                as well as flattened ``(tokens, features)`` inputs work.

        Returns:
            A ``(weights, top_k_indices)`` tuple. Both entries have the
            leading shape of ``inputs`` followed by a trailing axis of size
            ``k``: ``weights`` are the renormalised routing probabilities and
            ``top_k_indices`` the matching expert indices.
        """
        # Everything except the trailing feature/expert axis indexes tokens.
        # Reducing over these axes (instead of hard-coded (0, 1)) keeps the
        # result identical for 3-D inputs while accepting any leading rank.
        token_axes = tuple(range(inputs.ndim - 1))

        logits = nn.Dense(self.num_experts, use_bias=False)(inputs)
        probs = jax.nn.softmax(logits, axis=-1)

        top_k_weights, top_k_indices = jax.lax.top_k(probs, self.k)
        # Renormalise over the selected experts; the epsilon guards against a
        # zero denominator.
        weights = top_k_weights / (top_k_weights.sum(axis=-1, keepdims=True) + 1e-9)

        # Per-token indicator over experts: 1.0 where the expert was selected.
        mask = jax.nn.one_hot(top_k_indices, self.num_experts).sum(axis=-2)
        mask = jnp.minimum(mask, 1.0)

        # Fraction of tokens that selected each expert, and the mean routing
        # probability each expert received, both averaged over all tokens.
        fraction = mask.mean(axis=token_axes)
        mean_prob = probs.mean(axis=token_axes)

        # First term: mean squared log-sum-exp of the router logits.
        # Second term: num_experts * sum_e(fraction_e * mean_prob_e).
        logit_penalty = jnp.mean(jnp.square(jax.nn.logsumexp(logits, axis=-1)))
        balance_penalty = self.num_experts * jnp.sum(fraction * mean_prob)
        loss = logit_penalty + balance_penalty

        self.sow(
            "intermediates",
            "moe_loss",
            loss,
        )
        return weights, top_k_indices