Source code for memorax.networks.sequence_models.rnn

from typing import Mapping, Optional

import flax.linen as nn
import jax
from flax.core.frozen_dict import FrozenDict
from flax.core.scope import CollectionFilter, PRNGSequenceFilter
from flax.linen import initializers
from flax.linen.linear import default_kernel_init
from flax.linen.recurrent import Carry
from flax.typing import Initializer, InOutScanAxis
from memorax.networks.sequence_models.utils import (
    get_time_axis_and_input_shape,
    mask_carry,
)

from .sequence_model import SequenceModel


[docs] class RNN(SequenceModel): cell: nn.RNNCellBase unroll: int = 1 variable_axes: Mapping[CollectionFilter, InOutScanAxis] = FrozenDict() variable_broadcast: CollectionFilter = "params" variable_carry: CollectionFilter = False split_rngs: Mapping[PRNGSequenceFilter, bool] = FrozenDict({"params": False}) kernel_init: Initializer = default_kernel_init bias_init: Initializer = initializers.zeros_init()
[docs] @nn.compact def __call__( self, inputs: jax.Array, mask: jax.Array, initial_carry: Optional[Carry] = None, **kwargs, ): time_axis, input_shape = get_time_axis_and_input_shape(inputs) if initial_carry is None: initial_carry = self.cell.initialize_carry(jax.random.key(0), input_shape) carry: Carry = initial_carry def scan_fn(cell, carry, x, mask): carry = mask_carry( mask, carry, self.cell.initialize_carry(jax.random.key(0), input_shape) ) carry, y = cell(carry, x) return carry, y scan = nn.transforms.scan( scan_fn, in_axes=time_axis, out_axes=time_axis, unroll=self.unroll, variable_axes=self.variable_axes, variable_broadcast=self.variable_broadcast, variable_carry=self.variable_carry, split_rngs=self.split_rngs, ) carry, outputs = scan(self.cell, carry, inputs, mask) return carry, outputs
[docs] @nn.nowrap def initialize_carry(self, key, input_shape): return self.cell.initialize_carry(key, input_shape)