memorax.algorithms.GradientPPO
- class memorax.algorithms.GradientPPO
Bases: object
GradientPPO(cfg: memorax.algorithms.gradient_ppo.GradientPPOConfig, env: gymnax.environments.environment.Environment, env_params: gymnax.environments.environment.EnvParams, actor_network: flax.linen.module.Module, critic_network: flax.linen.module.Module, h_network: flax.linen.module.Module, actor_optimizer: optax._src.base.GradientTransformation, critic_optimizer: optax._src.base.GradientTransformation, h_optimizer: optax._src.base.GradientTransformation)
- cfg: GradientPPOConfig
- env: Environment
- env_params: EnvParams
- actor_network: Module
- critic_network: Module
- h_network: Module
- actor_optimizer: GradientTransformation
- critic_optimizer: GradientTransformation
- h_optimizer: GradientTransformation
- warmup(key, state, num_steps)
- Return type:
- Parameters:
key (Array)
state (GradientPPOState)
num_steps (int)
- train(key, state, num_steps)
- Return type:
- Parameters:
key (Array)
state (GradientPPOState)
num_steps (int)
- evaluate(key, state, num_steps)
- Return type:
- Parameters:
key (Array)
state (GradientPPOState)
num_steps (int)
- __init__(cfg, env, env_params, actor_network, critic_network, h_network, actor_optimizer, critic_optimizer, h_optimizer)
- Parameters:
cfg (GradientPPOConfig)
env (Environment)
env_params (EnvParams)
actor_network (Module)
critic_network (Module)
h_network (Module)
actor_optimizer (GradientTransformation)
critic_optimizer (GradientTransformation)
h_optimizer (GradientTransformation)
- Return type:
None
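The constructor takes three networks (actor, critic, h), each paired with its own optimizer. The sketch below shows one way the nine documented keyword arguments might be assembled. The helper `build_gradient_ppo_kwargs` and the factories `make_network` and `make_optimizer` are hypothetical names introduced for illustration and are not part of memorax. Placeholder strings stand in for real objects so the snippet runs without memorax, gymnax, flax, or optax installed.

```python
from typing import Any, Callable, Dict

# Role prefixes taken from the documented parameter names
# (actor_network / critic_network / h_network and their optimizers).
ROLES = ("actor", "critic", "h")


def build_gradient_ppo_kwargs(
    cfg: Any,
    env: Any,
    env_params: Any,
    make_network: Callable[[str], Any],
    make_optimizer: Callable[[str], Any],
) -> Dict[str, Any]:
    """Assemble keyword arguments matching the documented constructor.

    make_network and make_optimizer are hypothetical factories; each is
    called once per role with "actor", "critic", or "h".
    """
    kwargs: Dict[str, Any] = {"cfg": cfg, "env": env, "env_params": env_params}
    for role in ROLES:
        kwargs[f"{role}_network"] = make_network(role)
        kwargs[f"{role}_optimizer"] = make_optimizer(role)
    return kwargs


# Placeholder values only; real code would pass a GradientPPOConfig,
# a gymnax environment with its params, flax modules, and optax optimizers.
example = build_gradient_ppo_kwargs(
    cfg="cfg-placeholder",
    env="env-placeholder",
    env_params="env-params-placeholder",
    make_network=lambda role: f"{role}-network",
    make_optimizer=lambda role: f"{role}-optimizer",
)
print(sorted(example))
```

With real objects in place of the placeholders, the resulting dictionary could be unpacked as `GradientPPO(**kwargs)`. How the `state` argument for `warmup`, `train`, and `evaluate` is obtained is not documented on this page.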