Source code for memorax.networks.network
from typing import Optional
import flax.linen as nn
import jax
from memorax.networks import Identity
from memorax.networks.sequence_models.wrappers import SequenceModelWrapper
from memorax.utils.typing import Array
[docs]
class Network(nn.Module):
    """Chain a feature extractor, a torso, and a head into one module.

    The three sub-modules are applied in that order by ``__call__``. The
    torso is expected to return a ``(carry, output)`` pair and to expose an
    ``initialize_carry(key, input_shape)`` method, since both are used below.

    Attributes are documented inline next to each field.
    """

    # Applied first, directly to the observation.
    feature_extractor: nn.Module = Identity()
    # Applied second; receives the extracted features and the mask.
    torso: nn.Module = SequenceModelWrapper(Identity())
    # Applied last, to the torso's output.
    head: nn.Module = Identity()

    @nn.compact
    def __call__(
        self,
        observation: Array,
        mask: Array,
        **kwargs,
    ):
        """Run the observation through feature extractor, torso, and head.

        Args:
            observation: Input passed to ``feature_extractor``.
            mask: Forwarded to the torso as the ``mask`` keyword argument.
                Presumably marks episode boundaries / resets -- TODO confirm
                against the torso implementation.
            **kwargs: Extra keyword arguments forwarded unchanged to all
                three sub-modules, so each of them must accept them.

        Returns:
            A ``(carry, x)`` tuple: the carry returned by the torso and the
            output of the head.
        """
        # NOTE(review): initialize_carry tolerates ``torso is None`` but this
        # method calls the torso unconditionally -- confirm whether a missing
        # torso is actually a supported configuration.
        x = self.feature_extractor(observation, **kwargs)
        carry, x = self.torso(x, mask=mask, **kwargs)
        x = self.head(x, **kwargs)
        return carry, x

    @nn.nowrap
    def initialize_carry(self, input_shape):
        """Build the torso's initial carry for the given input shape.

        Args:
            input_shape: Passed through to ``torso.initialize_carry``.

        Returns:
            The value returned by ``torso.initialize_carry``, or ``None`` when
            no torso is configured.
        """
        if self.torso is None:
            return None
        # A fixed key is used, so the torso always receives the same key here.
        key = jax.random.key(0)
        return self.torso.initialize_carry(key, input_shape)