Source code for memorax.utils.transition

from typing import Self

import jax
from flax import struct

from memorax.utils.axes import add_time_axis, remove_time_axis
from memorax.utils.timestep import Timestep
from memorax.utils.typing import PyTree


[docs] @struct.dataclass(frozen=True) class Transition: first: Timestep | None = None second: Timestep | None = None carry: PyTree | None = None aux: PyTree | None = None
[docs] def to_sequence(self) -> Self: return jax.tree.map(add_time_axis, self)
[docs] def from_sequence(self) -> Self: return jax.tree.map(remove_time_axis, self)