memorax.algorithms.PPO

class memorax.algorithms.PPO [source]

Bases: object

PPO(cfg: memorax.algorithms.ppo.PPOConfig, env: gymnax.environments.environment.Environment, env_params: gymnax.environments.environment.EnvParams, actor: flax.linen.module.Module, critic: flax.linen.module.Module, actor_optimizer: optax._src.base.GradientTransformation, critic_optimizer: optax._src.base.GradientTransformation)

cfg: PPOConfig
env: Environment
env_params: EnvParams
actor: Module
critic: Module
actor_optimizer: GradientTransformation
critic_optimizer: GradientTransformation
init(key) [source]
warmup(key, state, num_steps) [source]

No warmup is needed for PPO.

train(key, state, num_steps) [source]
evaluate(key, state, num_steps, deterministic=True) [source]
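__init__(cfg, env, env_params, actor, critic, actor_optimizer, critic_optimizer)
Parameters:
    cfg (PPOConfig)
    env (Environment)
    env_params (EnvParams)
    actor (Module)
    critic (Module)
    actor_optimizer (GradientTransformation)
    critic_optimizer (GradientTransformation)
Return type:

None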

replace(**updates)

Returns a new object with the specified fields replaced by the given values.