Source code for memorax.utils.timestep
import jax
import jax.numpy as jnp
from flax import struct
from memorax.networks.sequence_models.utils import (add_time_axis,
remove_time_axis)
from memorax.utils.typing import Array