# Source code for memorax.networks.network
import flax.linen as nn
import jax
from memorax.networks import Identity
from memorax.utils.typing import Array, Carry
[docs]
class Network(nn.Module):
feature_extractor: nn.Module = Identity()
torso: nn.Module = Identity()
head: nn.Module = Identity()
[docs]
@nn.compact
def __call__(
self,
observation: Array,
done: Array,
action: Array,
reward: Array,
initial_carry: Array | None = None,
**kwargs,
) -> tuple[Carry, Array]:
x, embeddings = self.feature_extractor(
observation, action=action, reward=reward, done=done
)
match self.torso(
x,
done=done,
action=action,
reward=reward,
initial_carry=initial_carry,
**embeddings,
**kwargs,
):
case (carry, x):
pass
case x:
carry = None
x = self.head(x, action=action, reward=reward, done=done, **kwargs)
return carry, x
[docs]
@nn.nowrap
def initialize_carry(self, input_shape: tuple) -> Carry:
key = jax.random.key(0)
return getattr(self.torso, "initialize_carry", lambda k, s: None)(
key, input_shape
)