Source code for memorax.utils.timestep

from typing import Self

import jax
from flax import struct

from memorax.utils.axes import add_feature_axis, add_time_axis, remove_time_axis
from memorax.utils.typing import Array


@struct.dataclass(frozen=True)
class Timestep:
    """Immutable bundle of optional ``obs`` / ``action`` / ``reward`` / ``done`` arrays.

    Built with ``flax.struct.dataclass``, so an instance is a JAX pytree whose
    leaves are the non-``None`` fields. Every field defaults to ``None``.
    """

    # Field order is part of the positional constructor signature; keep it.
    obs: Array | None = None
    action: Array | None = None
    reward: Array | None = None
    done: Array | None = None

    def __iter__(self):
        """Yield ``obs``, ``done``, ``action`` and the reshaped ``reward``, in that order.

        The yield order deliberately differs from the field declaration order:
        ``done`` comes second and ``action`` third. ``reward`` is passed through
        ``add_feature_axis`` only when the fourth item is actually requested.
        """
        yield from (self.obs, self.done, self.action)
        yield add_feature_axis(self.reward)

    def to_sequence(self) -> Self:
        """Return a new ``Timestep`` with ``add_time_axis`` applied to every array leaf.

        ``None`` fields are empty pytree nodes, so they stay ``None``.
        """
        lift_leaf = add_time_axis
        return jax.tree_util.tree_map(lift_leaf, self)

    def from_sequence(self) -> Self:
        """Return a new ``Timestep`` with ``remove_time_axis`` applied to every array leaf.

        This mirrors ``to_sequence``. ``None`` fields are left untouched.
        """
        drop_leaf = remove_time_axis
        return jax.tree_util.tree_map(drop_leaf, self)