memorax.algorithms.R2D2State
- class memorax.algorithms.R2D2State
Bases: object
R2D2State(step: int, update_step: int, timestep: memorax.utils.timestep.Timestep, carry: tuple, env_state: gymnax.environments.environment.EnvState, params: flax.core.frozen_dict.FrozenDict[str, typing.Any], target_params: flax.core.frozen_dict.FrozenDict[str, typing.Any], optimizer_state: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex, Iterable[ForwardRef('ArrayTree')], Mapping[Any, ForwardRef('ArrayTree')]], buffer_state: flashbax.buffers.trajectory_buffer.TrajectoryBufferState)
- env_state: EnvState
- params: FrozenDict[str, Any]
- target_params: FrozenDict[str, Any]
- optimizer_state: Array | ndarray | bool | number | bool | int | float | complex | Iterable[ArrayTree] | Mapping[Any, ArrayTree]
- buffer_state: TrajectoryBufferState
- __init__(step, update_step, timestep, carry, env_state, params, target_params, optimizer_state, buffer_state)
- Parameters:
  - step (int)
  - update_step (int)
  - timestep (Timestep)
  - carry (tuple)
  - env_state (EnvState)
  - params (FrozenDict[str, Any])
  - target_params (FrozenDict[str, Any])
  - optimizer_state (Array | ndarray | bool | number | bool | int | float | complex | Iterable[ArrayTree] | Mapping[Any, ArrayTree])
  - buffer_state (TrajectoryBufferState)
- Return type:
  None
- replace(**updates)
Returns a new object replacing the specified fields with new values.
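
The copy-on-update behaviour of `replace` can be sketched without installing memorax. The example below is an illustration only: `ToyState` is a hypothetical stand-in that carries a few of the field names listed above, built on the standard library's `dataclasses` module. That the real `R2D2State` follows the same immutable-dataclass pattern is an assumption based on the method description, not something this page confirms.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in with a subset of the R2D2State field names."""

    step: int
    update_step: int
    params: Any
    target_params: Any

    def replace(self, **updates):
        # Same contract as the documented method: build a new object with
        # the given fields overridden and leave the original untouched.
        return dataclasses.replace(self, **updates)


state = ToyState(step=0, update_step=0, params={"w": 1.0}, target_params={"w": 1.0})

# Advance the step counter and swap in new parameters without mutating `state`.
new_state = state.replace(step=state.step + 1, params={"w": 0.9})
```

Fields that are not named in the call, here `update_step` and `target_params`, are carried over to the new object unchanged, and `state` itself still holds its original values.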