memorax.algorithms.SACState#

class memorax.algorithms.SACState[source]#

Bases: object

SACState(step: int, env_state: gymnax.environments.environment.EnvState, buffer_state: flashbax.buffers.trajectory_buffer.TrajectoryBufferState, actor_params: flax.core.frozen_dict.FrozenDict[str, typing.Any], critic_params: flax.core.frozen_dict.FrozenDict[str, typing.Any], critic_target_params: flax.core.frozen_dict.FrozenDict[str, typing.Any], alpha_params: flax.core.frozen_dict.FrozenDict[str, typing.Any], actor_optimizer_state: Union[jax.jaxlib._jax.Array, numpy.ndarray, numpy.bool, numpy.number, Iterable[ForwardRef('ArrayTree')], Mapping[Any, ForwardRef('ArrayTree')]], critic_optimizer_state: Union[jax.jaxlib._jax.Array, numpy.ndarray, numpy.bool, numpy.number, Iterable[ForwardRef('ArrayTree')], Mapping[Any, ForwardRef('ArrayTree')]], alpha_optimizer_state: Union[jax.jaxlib._jax.Array, numpy.ndarray, numpy.bool, numpy.number, Iterable[ForwardRef('ArrayTree')], Mapping[Any, ForwardRef('ArrayTree')]], obs: jax.jaxlib._jax.Array, done: jax.jaxlib._jax.Array, actor_hidden_state: jax.jaxlib._jax.Array)

step: int#
env_state: EnvState#
buffer_state: TrajectoryBufferState#
actor_params: FrozenDict[str, Any]#
critic_params: FrozenDict[str, Any]#
critic_target_params: FrozenDict[str, Any]#
alpha_params: FrozenDict[str, Any]#
actor_optimizer_state: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]#
critic_optimizer_state: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]#
alpha_optimizer_state: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]#
obs: Array#
done: Array#
actor_hidden_state: Array#
__init__(step, env_state, buffer_state, actor_params, critic_params, critic_target_params, alpha_params, actor_optimizer_state, critic_optimizer_state, alpha_optimizer_state, obs, done, actor_hidden_state)#
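Parameters:

step (int)
env_state (EnvState)
buffer_state (TrajectoryBufferState)
actor_params (FrozenDict[str, Any])
critic_params (FrozenDict[str, Any])
critic_target_params (FrozenDict[str, Any])
alpha_params (FrozenDict[str, Any])
actor_optimizer_state (Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree])
critic_optimizer_state (Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree])
alpha_optimizer_state (Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree])
obs (Array)
done (Array)
actor_hidden_state (Array)

Return type:

None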

replace(**updates)#

Returns a new object replacing the specified fields with new values.