memorax.algorithms.GradientPPOState

class memorax.algorithms.GradientPPOState

Bases: object

GradientPPOState(step: int, update_step: int, timestep: memorax.utils.timestep.Timestep, env_state: gymnax.environments.environment.EnvState, actor_params: flax.core.frozen_dict.FrozenDict[str, typing.Any], actor_optimizer_state: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex, Iterable[ForwardRef('ArrayTree')], Mapping[Any, ForwardRef('ArrayTree')]], actor_carry: jax.Array, critic_params: flax.core.frozen_dict.FrozenDict[str, typing.Any], critic_optimizer_state: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex, Iterable[ForwardRef('ArrayTree')], Mapping[Any, ForwardRef('ArrayTree')]], critic_carry: jax.Array, h_params: flax.core.frozen_dict.FrozenDict[str, typing.Any], h_optimizer_state: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex, Iterable[ForwardRef('ArrayTree')], Mapping[Any, ForwardRef('ArrayTree')]], h_carry: jax.Array)

step: int
update_step: int
timestep: Timestep
env_state: EnvState
actor_params: FrozenDict[str, Any]
actor_optimizer_state: Array | ndarray | numpy.bool | number | bool | int | float | complex | Iterable[ArrayTree] | Mapping[Any, ArrayTree]
actor_carry: Array
critic_params: FrozenDict[str, Any]
critic_optimizer_state: Array | ndarray | numpy.bool | number | bool | int | float | complex | Iterable[ArrayTree] | Mapping[Any, ArrayTree]
critic_carry: Array
h_params: FrozenDict[str, Any]
h_optimizer_state: Array | ndarray | numpy.bool | number | bool | int | float | complex | Iterable[ArrayTree] | Mapping[Any, ArrayTree]
h_carry: Array
__init__(step, update_step, timestep, env_state, actor_params, actor_optimizer_state, actor_carry, critic_params, critic_optimizer_state, critic_carry, h_params, h_optimizer_state, h_carry)
Parameters:
Each argument sets the field of the same name, with the type shown in the class signature.
Return type:

None

replace(**updates)

Returns a new object in which the fields named in updates are replaced with the given values.