"""Feature extraction for observations, actions, rewards and done flags (``memorax.networks.feature_extractor``)."""

from typing import Callable

import flax.linen as nn
import jax.numpy as jnp

from memorax.utils.typing import Array


class FeatureExtractor(nn.Module):
    """Embed an observation and, optionally, the action, reward and done flag.

    Each input goes through its own extractor. The resulting embeddings are
    concatenated along the last axis in the order observation, action, reward,
    done.

    Attributes:
        observation_extractor: Callable applied to the observation. Always used.
        action_extractor: Optional callable applied to the action. When
            ``None``, no action embedding is produced.
        reward_extractor: Optional callable applied to the reward. When
            ``None``, no reward embedding is produced.
        done_extractor: Optional callable applied to the done flag after it
            has been cast to ``int32``. When ``None``, no done embedding is
            produced.
    """

    observation_extractor: Callable
    action_extractor: Callable | None = None
    reward_extractor: Callable | None = None
    done_extractor: Callable | None = None

    def extract(
        self,
        embeddings: dict,
        key: str,
        extractor: Callable | None,
        x: Array | None = None,
    ) -> None:
        """Store ``extractor(x)`` under ``key`` in ``embeddings``, if possible.

        Nothing is written when either ``extractor`` or ``x`` is ``None``.

        Args:
            embeddings: Dict mutated in place with the new embedding.
            key: Key under which the embedding is stored.
            extractor: Callable producing the embedding, or ``None`` to skip.
            x: Input to embed, or ``None`` to skip.
        """
        if extractor is not None and x is not None:
            embeddings[key] = extractor(x)

    @nn.compact
    def __call__(
        self,
        observation: Array,
        action: Array | None = None,
        reward: Array | None = None,
        done: Array | None = None,
        **kwargs,
    ) -> tuple[Array, dict]:
        """Compute and concatenate all available embeddings.

        Args:
            observation: Input to ``observation_extractor``.
            action: Input to ``action_extractor``. Skipped when ``None``.
            reward: Input to ``reward_extractor``. Skipped when ``None``.
            done: Done flag, cast to ``int32`` before ``done_extractor``.
                Skipped when ``None``.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            A tuple ``(x, embeddings)``:

            - ``embeddings`` maps ``"observation_embedding"`` and whichever of
              ``"action_embedding"``, ``"reward_embedding"`` and
              ``"done_embedding"`` were computed to their arrays.
            - ``x`` is those arrays concatenated along the last axis, in
              insertion order. ``jnp.concatenate`` requires every
              non-last dimension to match across embeddings.
        """
        embeddings = {"observation_embedding": self.observation_extractor(observation)}
        self.extract(embeddings, "action_embedding", self.action_extractor, action)
        self.extract(embeddings, "reward_embedding", self.reward_extractor, reward)
        # ``extract`` skips ``None`` inputs, so ``done`` may legitimately be
        # ``None`` here; only cast when there is something to cast.
        done_input = None if done is None else done.astype(jnp.int32)
        self.extract(embeddings, "done_embedding", self.done_extractor, done_input)
        x = jnp.concatenate(list(embeddings.values()), axis=-1)
        return x, embeddings