Source code for memorax.utils.gae

import jax
import jax.numpy as jnp


@jax.jit
def generalized_advantage_estimatation(
    gamma: float,
    gae_lambda: float,
    final_value: jax.Array,
    transitions,
):
    """Compute Generalized Advantage Estimates (GAE) for a trajectory.

    Runs a reverse ``jax.lax.scan`` over ``transitions`` implementing::

        delta_t = r_t + gamma * V_{t+1} * (1 - d_t) - V_t
        A_t     = delta_t + gamma * gae_lambda * (1 - d_t) * A_{t+1}

    with the bootstrap ``V_T = final_value`` and ``A_T = 0``.

    Args:
        gamma: Discount factor.
        gae_lambda: GAE lambda, trading bias against variance of the
            advantage estimate.
        final_value: Value estimate for the state that follows the last
            transition. It seeds the scan carry, so it must have the same
            shape/dtype as a single time step of ``transitions.value``.
        transitions: A pytree whose leaves share a leading (time) axis,
            exposing ``reward``, ``done`` and ``value`` attributes. ``done``
            is used numerically as a 0/1 mask via ``1 - done``.

    Returns:
        A tuple ``(advantages, returns)``, both stacked along the leading
        time axis in the original forward order, with
        ``returns = advantages + transitions.value``.
    """

    def _step(carry, transition):
        next_advantage, next_value = carry
        # 1 where the episode continues past this step, 0 where it ended.
        # NOTE(review): a single `done` flag means time-limit truncations are
        # treated as terminal (no bootstrap) -- confirm this is intended.
        not_done = 1 - transition.done
        delta = transition.reward + gamma * next_value * not_done - transition.value
        advantage = delta + gamma * gae_lambda * not_done * next_advantage
        # This step's value becomes V_{t+1} for the previous step.
        return (advantage, transition.value), advantage

    init_carry = (jnp.zeros_like(final_value), final_value)
    # reverse=True walks time backwards but stacks outputs in forward order.
    _, advantages = jax.lax.scan(
        _step,
        init_carry,
        transitions,
        reverse=True,
        unroll=16,
    )
    returns = advantages + transitions.value
    return advantages, returns


# Correctly spelled alias; the misspelled name above is kept so existing
# callers keep working.
generalized_advantage_estimation = generalized_advantage_estimatation