Source code for memorax.networks.network

import flax.linen as nn
import jax

from memorax.networks import Identity
from memorax.utils.typing import Array, Carry


[docs] class Network(nn.Module): feature_extractor: nn.Module = Identity() torso: nn.Module = Identity() head: nn.Module = Identity()
[docs] @nn.compact def __call__( self, observation: Array, done: Array, action: Array, reward: Array, initial_carry: Array | None = None, **kwargs, ) -> tuple[Carry, Array]: x, embeddings = self.feature_extractor( observation, action=action, reward=reward, done=done ) match self.torso( x, done=done, action=action, reward=reward, initial_carry=initial_carry, **embeddings, **kwargs, ): case (carry, x): pass case x: carry = None x = self.head(x, action=action, reward=reward, done=done, **kwargs) return carry, x
[docs] @nn.nowrap def initialize_carry(self, input_shape: tuple) -> Carry: key = jax.random.key(0) return getattr(self.torso, "initialize_carry", lambda k, s: None)( key, input_shape )