memorax.algorithms.SAC#

class memorax.algorithms.SAC[source]#

Bases: object

SAC(cfg: memorax.algorithms.sac.SACConfig, env: gymnax.environments.environment.Environment, env_params: gymnax.environments.environment.EnvParams, actor_network: flax.linen.module.Module, critic_network: flax.linen.module.Module, alpha_network: flax.linen.module.Module, actor_optimizer: optax._src.base.GradientTransformation, critic_optimizer: optax._src.base.GradientTransformation, alpha_optimizer: optax._src.base.GradientTransformation, buffer: flashbax.buffers.trajectory_buffer.TrajectoryBuffer)

cfg: SACConfig#
env: Environment#
env_params: EnvParams#
actor_network: Module#
critic_network: Module#
alpha_network: Module#
actor_optimizer: GradientTransformation#
critic_optimizer: GradientTransformation#
alpha_optimizer: GradientTransformation#
buffer: TrajectoryBuffer#
init(key)[source]#
Parameters:

key (Array)

warmup(key, state, num_steps)[source]#
Return type:

SACState

Parameters:

key
state
num_steps

train(key, state, num_steps)[source]#
Parameters:

key
state
num_steps

evaluate(key, state, num_steps)[source]#
Return type:

SACState

Parameters:

key
state
num_steps

__init__(cfg, env, env_params, actor_network, critic_network, alpha_network, actor_optimizer, critic_optimizer, alpha_optimizer, buffer)#
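Parameters:

cfg (SACConfig)
env (Environment)
env_params (EnvParams)
actor_network (Module)
critic_network (Module)
alpha_network (Module)
actor_optimizer (GradientTransformation)
critic_optimizer (GradientTransformation)
alpha_optimizer (GradientTransformation)
buffer (TrajectoryBuffer)
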
Return type:

None