Source code for memorax.utils.transition

from typing import Optional

import jax
import jax.numpy as jnp
from flax import struct

from memorax.utils.typing import Array


@struct.dataclass(frozen=True)
class Transition:
    """Frozen flax struct dataclass bundling optional per-step rollout fields.

    Every field defaults to ``None``; each derived property asserts that the
    fields it reads have been populated. The episode properties run
    ``jax.lax.scan`` along the leading axis of ``done`` (and ``reward``),
    which is presumably the time axis -- confirm against the code that
    builds these batches.
    """

    # No field is required; anything a caller does not record stays None.
    obs: Optional[Array] = None
    action: Optional[Array] = None
    reward: Optional[Array] = None
    done: Optional[Array] = None
    info: Optional[dict] = None
    prev_action: Optional[Array] = None
    prev_reward: Optional[Array] = None
    prev_done: Optional[Array] = None
    next_obs: Optional[Array] = None
    log_prob: Optional[Array] = None
    value: Optional[Array] = None
    env_state: Optional[Array] = None

    @property
    def num_episodes(self) -> Array:
        """Sum of ``done`` over all of its elements."""
        assert self.done is not None
        finished_flags = self.done
        return finished_flags.sum()

    @property
    def episode_lengths(self):
        """Number of steps in each episode, reported where the episode ends.

        The result has the shape of ``done``. Entries where ``done`` is
        false are NaN, so the integer counts are returned in a floating
        dtype.
        """
        assert self.done is not None

        def count_steps(elapsed, finished):
            # Count the current step first, so the finishing step is
            # included in the length it reports.
            elapsed_now = elapsed + 1
            nothing = jnp.zeros_like(elapsed_now)
            emitted = jnp.where(finished, elapsed_now, nothing)
            # Restart the counter right after an episode finishes.
            carried = jnp.where(finished, nothing, elapsed_now)
            return carried, emitted

        start = jnp.zeros_like(self.done[0], dtype=jnp.int32)
        lengths = jax.lax.scan(count_steps, start, self.done)[1]
        # Mask everything that is not an episode boundary with NaN.
        return jnp.where(self.done, lengths, jnp.nan)

    @property
    def episode_returns(self):
        """Summed reward of each episode, reported where the episode ends.

        The rewards are added up without discounting. Entries where
        ``done`` is false are NaN.
        """
        assert self.reward is not None
        assert self.done is not None

        def accumulate(running, step_data):
            reward_t, finished = step_data
            total = running + reward_t
            nothing = jnp.zeros_like(total)
            emitted = jnp.where(finished, total, nothing)
            # Restart the running sum right after an episode finishes.
            carried = jnp.where(finished, nothing, total)
            return carried, emitted

        start = jnp.zeros_like(self.reward[0])
        returns = jax.lax.scan(accumulate, start, (self.reward, self.done))[1]
        # Mask everything that is not an episode boundary with NaN.
        return jnp.where(self.done, returns, jnp.nan)

    @property
    def losses(self):
        """Mean of every ``info`` entry whose key starts with ``"losses"``."""
        assert self.info is not None
        return {
            name: entry.mean()
            for name, entry in self.info.items()
            if name.startswith("losses")
        }

    @property
    def infos(self):
        """Mean of every ``info`` entry whose key does not start with ``"losses"``."""
        assert self.info is not None
        return {
            name: entry.mean()
            for name, entry in self.info.items()
            if not name.startswith("losses")
        }