Source code for memorax.networks.feature_extractor
from typing import Optional
import flax.linen as nn
import jax.numpy as jnp
[docs]
class FeatureExtractor(nn.Module):
observation_extractor: nn.Module
action_extractor: Optional[nn.Module] = None
reward_extractor: Optional[nn.Module] = None
done_extractor: Optional[nn.Module] = None
[docs]
def extract(
self,
features: list,
extractor: Optional[nn.Module],
x: Optional[jnp.ndarray] = None,
):
if extractor is not None and x is not None:
features.append(extractor(x))
[docs]
@nn.compact
def __call__(
self,
observation: jnp.ndarray,
action: jnp.ndarray,
reward: jnp.ndarray,
done: jnp.ndarray,
**kwargs,
):
features = [self.observation_extractor(observation)]
self.extract(features, self.action_extractor, action)
self.extract(features, self.reward_extractor, reward)
self.extract(features, self.done_extractor, done.astype(jnp.int32))
features = jnp.concatenate(features, axis=-1)
return features