memorax.utils.Transition

class memorax.utils.Transition

Bases: object

Transition(obs: Optional[jax.Array] = None, action: Optional[jax.Array] = None, reward: Optional[jax.Array] = None, done: Optional[jax.Array] = None, info: Optional[dict] = None, prev_action: Optional[jax.Array] = None, prev_reward: Optional[jax.Array] = None, prev_done: Optional[jax.Array] = None, next_obs: Optional[jax.Array] = None, log_prob: Optional[jax.Array] = None, value: Optional[jax.Array] = None, env_state: Optional[jax.Array] = None)

obs: Array | None = None
action: Array | None = None
reward: Array | None = None
done: Array | None = None
info: dict | None = None
prev_action: Array | None = None
prev_reward: Array | None = None
prev_done: Array | None = None
next_obs: Array | None = None
log_prob: Array | None = None
value: Array | None = None
env_state: Array | None = None
property num_episodes: Array
property episode_lengths
property episode_returns
property losses
property infos
__init__(obs=None, action=None, reward=None, done=None, info=None, prev_action=None, prev_reward=None, prev_done=None, next_obs=None, log_prob=None, value=None, env_state=None)
Parameters:
  • obs (Array | None)

  • action (Array | None)

  • reward (Array | None)

  • done (Array | None)

  • info (dict | None)

  • prev_action (Array | None)

  • prev_reward (Array | None)

  • prev_done (Array | None)

  • next_obs (Array | None)

  • log_prob (Array | None)

  • value (Array | None)

  • env_state (Array | None)

Return type:

None

replace(**updates)

Returns a new object replacing the specified fields with new values.