memorax.algorithms.MAPPO

class memorax.algorithms.MAPPO [source]

Bases: object

MAPPO(cfg: memorax.algorithms.mappo.MAPPOConfig, env: gymnax.environments.environment.Environment, env_params: gymnax.environments.environment.EnvParams, actor_network: flax.linen.module.Module, critic_network: flax.linen.module.Module, actor_optimizer: optax._src.base.GradientTransformation, critic_optimizer: optax._src.base.GradientTransformation)

cfg: MAPPOConfig
env: Environment
env_params: EnvParams
actor_network: Module
critic_network: Module
actor_optimizer: GradientTransformation
critic_optimizer: GradientTransformation
init(key) [source]
Return type:

MAPPOState

Parameters:

key (Array)

warmup(key, state, num_steps) [source]
Return type:

MAPPOState

Parameters:

key
state
num_steps

train(key, state, num_steps) [source]
Return type:

MAPPOState

Parameters:

key
state
num_steps

evaluate(key, state, num_steps) [source]
Return type:

MAPPOState

Parameters:

key
state
num_steps

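__init__(cfg, env, env_params, actor_network, critic_network, actor_optimizer, critic_optimizer)
Parameters:

cfg (MAPPOConfig)
env (Environment)
env_params (EnvParams)
actor_network (Module)
critic_network (Module)
actor_optimizer (GradientTransformation)
critic_optimizer (GradientTransformation)

Return type:

None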