memorax.algorithms.DQNState#

class memorax.algorithms.DQNState[source]#

Bases: object

DQNState(step: int, obs: jax.jaxlib._jax.Array, done: jax.jaxlib._jax.Array, hidden_state: tuple, env_state: gymnax.environments.environment.EnvState, params: flax.core.frozen_dict.FrozenDict[str, typing.Any], target_params: flax.core.frozen_dict.FrozenDict[str, typing.Any], optimizer_state: Union[jax.jaxlib._jax.Array, numpy.ndarray, numpy.bool, numpy.number, Iterable[ForwardRef('ArrayTree')], Mapping[Any, ForwardRef('ArrayTree')]], buffer_state: flashbax.buffers.trajectory_buffer.TrajectoryBufferState)

step: int#
obs: Array#
done: Array#
hidden_state: tuple#
env_state: EnvState#
params: FrozenDict[str, Any]#
target_params: FrozenDict[str, Any]#
optimizer_state: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]#
buffer_state: TrajectoryBufferState#
__init__(step, obs, done, hidden_state, env_state, params, target_params, optimizer_state, buffer_state)#
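Parameters:

step (int)
obs (Array)
done (Array)
hidden_state (tuple)
env_state (EnvState)
params (FrozenDict[str, Any])
target_params (FrozenDict[str, Any])
optimizer_state (Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree])
buffer_state (TrajectoryBufferState)

Return type:

None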

replace(**updates)#

Returns a new object replacing the specified fields with new values.