Source code for memorax.networks.blocks.moe

from typing import Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp

from memorax.utils.axes import get_input_shape
from memorax.utils.typing import Array, Carry, Key

from .base import Block
from .router import TopKRouter


class MoE(nn.Module, Block):
    """Mixture of Experts block for horizontal scaling.

    Every expert is evaluated on the full input; the router then picks, per
    position, which expert outputs to keep and how to weight them.

    Attributes are declared as Flax module fields:
      * ``experts`` -- the expert sub-blocks, each called as
        ``expert(inputs, done=..., initial_carry=..., **kwargs)`` and expected
        to return ``(carry, output)``.
      * ``router`` -- called as ``router(inputs)`` and expected to return
        ``(weights, indices)``.
    """

    experts: Sequence[nn.Module]
    router: TopKRouter

    @nn.compact
    def __call__(
        self,
        inputs: Array,
        done: Array | None = None,
        initial_carry: Carry | None = None,
        **kwargs,
    ) -> tuple[Carry, Array]:
        """Run all experts and combine the routed outputs.

        Args:
            inputs: Rank-3 array; its first two axes are treated as batch and
                sequence (the last axis is presumably features -- confirm
                against callers).
            done: Optional array forwarded unchanged to every expert.
            initial_carry: One carry per expert, in the same order as
                ``experts``. When ``None``, a fresh carry is built with
                ``initialize_carry`` using ``jax.random.key(0)``.
            **kwargs: Forwarded unchanged to every expert.

        Returns:
            A pair ``(carry, output)`` where ``carry`` is a tuple holding each
            expert's updated carry and ``output`` is the router-weighted sum
            of the selected expert outputs.
        """
        if initial_carry is None:
            initial_carry = self.initialize_carry(
                jax.random.key(0), get_input_shape(inputs)
            )

        weights, indices = self.router(inputs)
        batch_size, seq_len, _ = inputs.shape

        # Dense evaluation: each expert sees the whole input with its own carry.
        per_expert = [
            expert(inputs, done=done, initial_carry=expert_carry, **kwargs)
            for expert, expert_carry in zip(self.experts, initial_carry)
        ]
        next_carry = tuple(new_carry for new_carry, _ in per_expert)
        # Leading axis of the stack enumerates experts.
        expert_outputs = jnp.stack(
            [expert_out for _, expert_out in per_expert], axis=0
        )

        # Build batch/sequence index grids shaped like ``indices`` so that the
        # advanced index below picks, for every (batch, seq, k) slot, the
        # output of expert ``indices[batch, seq, k]`` at that same position.
        gather_shape = indices.shape
        batch_grid = jnp.broadcast_to(
            jnp.arange(batch_size)[:, None, None], gather_shape
        )
        seq_grid = jnp.broadcast_to(
            jnp.arange(seq_len)[None, :, None], gather_shape
        )
        routed = expert_outputs[indices, batch_grid, seq_grid]

        # Weighted sum over the k selected experts.
        mixed = jnp.einsum("bskf,bsk->bsf", routed, weights)
        return next_carry, mixed

    @nn.nowrap
    def initialize_carry(self, key: Key, input_shape: tuple) -> Carry:
        """Build the initial carry as a tuple with one entry per expert.

        The same ``key`` and ``input_shape`` are handed to every expert's own
        ``initialize_carry``.
        """
        carries = [
            expert.initialize_carry(key, input_shape) for expert in self.experts
        ]
        return tuple(carries)