memorax.algorithms.GradientPPO

class memorax.algorithms.GradientPPO [source]

Bases: object

GradientPPO(cfg: memorax.algorithms.gradient_ppo.GradientPPOConfig, env: gymnax.environments.environment.Environment, env_params: gymnax.environments.environment.EnvParams, actor_network: flax.linen.module.Module, critic_network: flax.linen.module.Module, h_network: flax.linen.module.Module, actor_optimizer: optax._src.base.GradientTransformation, critic_optimizer: optax._src.base.GradientTransformation, h_optimizer: optax._src.base.GradientTransformation)

cfg: GradientPPOConfig
env: Environment
env_params: EnvParams
actor_network: Module
critic_network: Module
h_network: Module
actor_optimizer: GradientTransformation
critic_optimizer: GradientTransformation
h_optimizer: GradientTransformation
init(key) [source]
Return type:

GradientPPOState

Parameters:

key (Array)

warmup(key, state, num_steps) [source]
Return type:

GradientPPOState

Parameters:

key
state
num_steps
train(key, state, num_steps) [source]
Return type:

GradientPPOState

Parameters:

key
state
num_steps
evaluate(key, state, num_steps) [source]
Return type:

GradientPPOState

Parameters:

key
state
num_steps
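__init__(cfg, env, env_params, actor_network, critic_network, h_network, actor_optimizer, critic_optimizer, h_optimizer)
Parameters:

cfg (GradientPPOConfig)
env (Environment)
env_params (EnvParams)
actor_network (Module)
critic_network (Module)
h_network (Module)
actor_optimizer (GradientTransformation)
critic_optimizer (GradientTransformation)
h_optimizer (GradientTransformation)

Return type:

None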