Source code for memorax.networks.sequence_models.memoroid

from abc import abstractmethod
from typing import Optional, Tuple

import jax
import jax.numpy as jnp
from flax import linen as nn
from memorax.utils.typing import Array, Carry

from .sequence_model import SequenceModel
from .utils import broadcast_mask, get_input_shape


[docs] class MemoroidCellBase(nn.Module):
[docs] @abstractmethod def __call__(self, x: Array, **kwargs) -> Carry: ...
[docs] @abstractmethod def binary_operator(self, a: Carry, b: Carry) -> Carry: ...
[docs] @abstractmethod def read(self, h: Carry, x: Array, **kwargs) -> Array: ...
[docs] @abstractmethod def initialize_carry(
self, key: jax.Array, input_shape: Tuple[int, ...] ) -> Carry: ...
[docs] class Memoroid(SequenceModel): cell: MemoroidCellBase
[docs] @nn.compact def __call__( self, inputs: Array, mask: Array, initial_carry: Optional[Carry] = None, **kwargs, ) -> Tuple[Carry, Array]: if initial_carry is None: input_shape = get_input_shape(inputs) initial_carry = self.cell.initialize_carry( jax.random.key(0), input_shape ) z = self.cell(inputs, **kwargs) z = jax.tree.map( lambda c, e: jnp.concatenate([c, e], axis=1), initial_carry, z, ) reset = jnp.concatenate([jnp.zeros((mask.shape[0], 1)), mask], axis=1) reset = reset[..., None] @jax.vmap def binary_operator(lhs, rhs): lhs_i, rhs_i = lhs lhs_j, rhs_j = rhs combined = self.cell.binary_operator(lhs_i, lhs_j) out = jax.tree.map( lambda lj, c: jnp.where(broadcast_mask(rhs_j, lj), lj, c), lhs_j, combined, ) return out, jnp.maximum(rhs_i, rhs_j) h, _ = jax.lax.associative_scan(binary_operator, (z, reset), axis=1) next_carry = jax.tree.map(lambda s: s[:, -1:], h) h = jax.tree.map(lambda s: s[:, 1:], h) y = self.cell.read(h, inputs, **kwargs) return next_carry, y
[docs] def initialize_carry(self, key: jax.Array, input_shape: Tuple[int, ...]) -> Carry: return self.cell.initialize_carry(key, input_shape)