memorax.algorithms.SACState
- class memorax.algorithms.SACState
Bases: object

SACState(
    step: int,
    update_step: int,
    timestep: memorax.utils.timestep.Timestep,
    env_state: gymnax.environments.environment.EnvState,
    buffer_state: flashbax.buffers.trajectory_buffer.TrajectoryBufferState,
    actor_params: flax.core.frozen_dict.FrozenDict[str, typing.Any],
    critic_params: flax.core.frozen_dict.FrozenDict[str, typing.Any],
    critic_target_params: flax.core.frozen_dict.FrozenDict[str, typing.Any],
    alpha_params: flax.core.frozen_dict.FrozenDict[str, typing.Any],
    actor_optimizer_state: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex, Iterable[ForwardRef('ArrayTree')], Mapping[Any, ForwardRef('ArrayTree')]],
    critic_optimizer_state: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex, Iterable[ForwardRef('ArrayTree')], Mapping[Any, ForwardRef('ArrayTree')]],
    alpha_optimizer_state: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex, Iterable[ForwardRef('ArrayTree')], Mapping[Any, ForwardRef('ArrayTree')]],
    actor_carry: jax.Array,
    critic_carry: jax.Array
)
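SAC with automatic entropy tuning learns the temperature α by minimizing J(log α) = E[-log α · (log π(a|s) + target_entropy)]. The `alpha_params` and `alpha_optimizer_state` fields suggest that SACState carries a learned temperature of this kind, but that is an assumption: the log-α parameterization, the target-entropy heuristic, and the helper names below (`temperature_loss`, `temperature_grad`, `sgd_step`) are hypothetical and not part of the memorax API. The sketch uses plain Python floats so the arithmetic is easy to follow; in JAX the loss would normally be differentiated with `jax.grad` and stepped with the stored optimizer state.

```python
import math
from typing import Sequence


def temperature_loss(log_alpha: float, log_probs: Sequence[float], target_entropy: float) -> float:
    """Batch mean of -log_alpha * (log_pi + target_entropy).

    log_probs are treated as constants, so no gradient flows into the policy.
    """
    return sum(-log_alpha * (lp + target_entropy) for lp in log_probs) / len(log_probs)


def temperature_grad(log_probs: Sequence[float], target_entropy: float) -> float:
    """Derivative of temperature_loss with respect to log_alpha."""
    return sum(-(lp + target_entropy) for lp in log_probs) / len(log_probs)


def sgd_step(log_alpha: float, log_probs: Sequence[float], target_entropy: float, lr: float) -> float:
    """One plain gradient-descent step on log_alpha (stand-in for the real optimizer)."""
    return log_alpha - lr * temperature_grad(log_probs, target_entropy)


# Hypothetical numbers: a 2-D action space and the common target_entropy = -action_dim heuristic.
action_dim = 2
target_entropy = -float(action_dim)
log_probs = [-1.0, -1.5]  # sampled log pi(a|s); entropy estimate = 1.25, above the target
log_alpha = 0.0

new_log_alpha = sgd_step(log_alpha, log_probs, target_entropy, lr=0.1)
print(math.exp(log_alpha), math.exp(new_log_alpha))  # alpha before and after the step
```

Because the estimated entropy (1.25) exceeds the target (-2), the gradient is positive and the step lowers log α, which weakens the entropy bonus. Some implementations minimize -α · (log π + target_entropy) instead; its gradient with respect to log α has the same sign, so α moves in the same direction.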
- step: int
- update_step: int
- timestep: Timestep
- env_state: EnvState
- buffer_state: TrajectoryBufferState
- actor_params: FrozenDict[str, Any]
- critic_params: FrozenDict[str, Any]
- critic_target_params: FrozenDict[str, Any]
- alpha_params: FrozenDict[str, Any]
- actor_optimizer_state: jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex | Iterable[ArrayTree] | Mapping[Any, ArrayTree]
- critic_optimizer_state: jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex | Iterable[ArrayTree] | Mapping[Any, ArrayTree]
- alpha_optimizer_state: jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex | Iterable[ArrayTree] | Mapping[Any, ArrayTree]
- actor_carry: Array
- critic_carry: Array
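SACState stores `critic_target_params` next to `critic_params`. Standard SAC moves the target critic slowly toward the online critic with Polyak averaging, target ← tau * online + (1 - tau) * target. Whether memorax uses this soft update, and with which tau, is an assumption; the sketch below uses a hypothetical `polyak_update` helper over nested dicts of floats (a flax `FrozenDict` is also a `Mapping`, so the same recursion applies). With JAX arrays the update is usually a single call: `jax.tree_util.tree_map(lambda o, t: tau * o + (1 - tau) * t, online, target)`.

```python
from typing import Any, Mapping


def polyak_update(online: Mapping[str, Any], target: Mapping[str, Any], tau: float) -> dict:
    """Return tau * online + (1 - tau) * target, applied leaf by leaf.

    Nested mappings (such as per-layer parameter dicts) are walked recursively;
    every other value is treated as a numeric leaf.
    """
    updated = {}
    for key, online_leaf in online.items():
        target_leaf = target[key]
        if isinstance(online_leaf, Mapping):
            updated[key] = polyak_update(online_leaf, target_leaf, tau)
        else:
            updated[key] = tau * online_leaf + (1.0 - tau) * target_leaf
    return updated


# Hypothetical parameter trees standing in for critic_params / critic_target_params.
critic_params = {"dense": {"kernel": 1.0, "bias": 0.0}}
critic_target_params = {"dense": {"kernel": 0.0, "bias": 1.0}}

critic_target_params = polyak_update(critic_params, critic_target_params, tau=0.005)
print(critic_target_params)
```

Because the state is immutable, the new target tree would be stored with something like `state.replace(critic_target_params=new_target)` rather than by assignment.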
- __init__(step, update_step, timestep, env_state, buffer_state, actor_params, critic_params, critic_target_params, alpha_params, actor_optimizer_state, critic_optimizer_state, alpha_optimizer_state, actor_carry, critic_carry)
- Parameters:
step (int)
update_step (int)
timestep (Timestep)
env_state (EnvState)
buffer_state (TrajectoryBufferState)
actor_params (FrozenDict[str, Any])
critic_params (FrozenDict[str, Any])
critic_target_params (FrozenDict[str, Any])
alpha_params (FrozenDict[str, Any])
actor_optimizer_state (jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex | Iterable[ArrayTree] | Mapping[Any, ArrayTree])
critic_optimizer_state (jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex | Iterable[ArrayTree] | Mapping[Any, ArrayTree])
alpha_optimizer_state (jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex | Iterable[ArrayTree] | Mapping[Any, ArrayTree])
actor_carry (Array)
critic_carry (Array)
- Return type:
None
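`actor_carry` and `critic_carry` are single `jax.Array` values, which points to the hidden state of a recurrent (memory) network being threaded through environment steps. A common pattern for recurrent agents is to reset that hidden state for every environment whose episode has just ended, so memory does not leak across episodes. The sketch below assumes a carry shaped (num_envs, hidden_size) and a hypothetical `reset_carry` helper written in plain Python; memorax may handle resets differently (for example inside the memory module itself). In JAX the same selection is typically written `jnp.where(done[:, None], initial_carry, carry)`.

```python
from typing import List, Sequence


def reset_carry(carry: List[List[float]], done: Sequence[bool], initial: Sequence[float]) -> List[List[float]]:
    """Reset the hidden state of every environment whose episode just ended.

    carry has one row per parallel environment; rows with done=True are replaced
    by a fresh copy of `initial`, all other rows are kept unchanged.
    """
    return [list(initial) if is_done else row for row, is_done in zip(carry, done)]


# Two parallel environments with a 3-unit hidden state (hypothetical sizes and values).
actor_carry = [[0.3, -0.1, 0.7], [0.2, 0.5, -0.4]]
done = [False, True]

actor_carry = reset_carry(actor_carry, done, initial=[0.0, 0.0, 0.0])
print(actor_carry)  # [[0.3, -0.1, 0.7], [0.0, 0.0, 0.0]]
```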
- replace(**updates)
Returns a new object replacing the specified fields with new values.
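Because `replace` returns a new object, a training step never mutates the state in place; it builds the next state from the previous one, which is the pattern jitted JAX loops rely on. Constructing a real `SACState` requires environment, buffer, and optimizer state, so the sketch below mirrors the same contract with a hypothetical frozen stdlib dataclass, `ToyState`. Implementing `replace` with `dataclasses.replace` is an assumption about how such a method is usually written (it is how `flax.struct.dataclass` provides it), not a statement about memorax internals.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyState:
    """Hypothetical two-field stand-in for SACState."""

    step: int
    update_step: int

    def replace(self, **updates):
        """Return a new object with the given fields replaced; self is left unchanged."""
        return dataclasses.replace(self, **updates)


state = ToyState(step=0, update_step=0)
new_state = state.replace(step=state.step + 1)

print(state)      # ToyState(step=0, update_step=0)
print(new_state)  # ToyState(step=1, update_step=0)
```

With the real class the same update reads `state = state.replace(step=state.step + 1)`; fields not named in the call are carried over unchanged.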