memorax.algorithms.MAPPO
- class memorax.algorithms.MAPPO
Bases: object
MAPPO(cfg: memorax.algorithms.mappo.MAPPOConfig, env: gymnax.environments.environment.Environment, env_params: gymnax.environments.environment.EnvParams, actor_network: flax.linen.module.Module, critic_network: flax.linen.module.Module, actor_optimizer: optax._src.base.GradientTransformation, critic_optimizer: optax._src.base.GradientTransformation)
- cfg: MAPPOConfig
- env: Environment
- env_params: EnvParams
- actor_network: Module
- critic_network: Module
- actor_optimizer: GradientTransformation
- critic_optimizer: GradientTransformation
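The separate actor_network/critic_network fields, each with its own optimizer, match the usual MAPPO design: each agent's policy is trained with the PPO clipped surrogate objective, and a (typically centralized) critic is regressed onto return targets. The sketch below shows those two losses in plain Python under that assumption. The function names and the clip_eps default are hypothetical and not taken from memorax, whose actual losses may also include terms such as an entropy bonus, value clipping, or advantage normalization.

```python
import math
from typing import Sequence


def clipped_surrogate_loss(
    new_log_probs: Sequence[float],
    old_log_probs: Sequence[float],
    advantages: Sequence[float],
    clip_eps: float = 0.2,  # hypothetical default, not a confirmed MAPPOConfig field
) -> float:
    """Actor loss: negative mean of the PPO clipped surrogate objective."""
    total = 0.0
    for new_lp, old_lp, adv in zip(new_log_probs, old_log_probs, advantages):
        # Probability ratio pi_new(a|o) / pi_old(a|o), computed from log-probs.
        ratio = math.exp(new_lp - old_lp)
        clipped_ratio = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # Keep the pessimistic (smaller) of the unclipped and clipped terms.
        total += min(ratio * adv, clipped_ratio * adv)
    return -total / len(advantages)


def value_loss(values: Sequence[float], returns: Sequence[float]) -> float:
    """Critic loss: mean squared error between predicted values and return targets."""
    return sum((v - r) ** 2 for v, r in zip(values, returns)) / len(returns)
```

Because the two losses depend on different parameters, applying actor_optimizer and critic_optimizer to them separately lets the policy and value function use different learning rates or schedules.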
- warmup(key, state, num_steps)
- Return type:
- Parameters:
key (Array)
state (MAPPOState)
num_steps (int)
- train(key, state, num_steps)
- Return type:
- Parameters:
key (Array)
state (MAPPOState)
num_steps (int)
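This page does not say how train estimates advantages for the policy update. PPO-family methods, including the original MAPPO formulation, commonly use Generalized Advantage Estimation (GAE), so the following plain-Python sketch assumes that choice. The helper name and the gamma and gae_lambda parameters are hypothetical and are not confirmed MAPPOConfig fields.

```python
from typing import List, Sequence


def gae_advantages(
    rewards: Sequence[float],
    values: Sequence[float],
    dones: Sequence[bool],
    last_value: float,
    gamma: float = 0.99,  # hypothetical discount factor
    gae_lambda: float = 0.95,  # hypothetical GAE lambda
) -> List[float]:
    """Generalized Advantage Estimation over one rollout, computed backwards in time.

    dones[t] is True if the episode ended after step t; the bootstrap value of
    the next state and the carried-over advantage are then masked out.
    """
    advantages = [0.0] * len(rewards)
    next_value = last_value
    next_advantage = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # One-step TD error: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        # A_t = delta_t + gamma * lambda * A_{t+1}, reset at episode boundaries
        next_advantage = delta + gamma * gae_lambda * not_done * next_advantage
        advantages[t] = next_advantage
        next_value = values[t]
    return advantages
```

Under this assumption, the critic's regression targets would be the advantages plus the value estimates, i.e. returns[t] = advantages[t] + values[t].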
- evaluate(key, state, num_steps)
- Return type:
- Parameters:
key (Array)
state (MAPPOState)
num_steps (int)
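evaluate takes the same (key, state, num_steps) arguments as train. A common way to summarize such a run is the mean return over the episodes that finish within the step budget. The sketch below assumes a flat sequence of per-step rewards and episode-termination flags for a single environment with a single shared (team) reward; that data layout and the helper name are hypothetical, since the return type is not shown on this page.

```python
from typing import List, Sequence


def completed_episode_returns(
    rewards: Sequence[float], dones: Sequence[bool]
) -> List[float]:
    """Sum rewards per episode, keeping only episodes that finish in the window."""
    returns: List[float] = []
    running = 0.0
    for reward, done in zip(rewards, dones):
        running += reward
        if done:
            returns.append(running)
            running = 0.0
    # A trailing partial episode (no final done flag) is discarded.
    return returns
```

Discarding the unfinished trailing episode avoids biasing the average toward short, truncated returns when num_steps does not align with episode boundaries.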
- __init__(cfg, env, env_params, actor_network, critic_network, actor_optimizer, critic_optimizer)
- Parameters:
cfg (MAPPOConfig)
env (Environment)
env_params (EnvParams)
actor_network (Module)
critic_network (Module)
actor_optimizer (GradientTransformation)
critic_optimizer (GradientTransformation)
- Return type:
None