# Source code for memorax.utils.update
import jax
from optax import incremental_update, periodic_update
from optax._src import base
from memorax.utils.typing import Array
[docs]
def periodic_incremental_update(
    new_tensors: base.Params,
    old_tensors: base.Params,
    steps: Array,
    update_period: int,
    step_size: float,
) -> base.Params:
    """Refresh a slow parameter copy with a Polyak blend, but only periodically.

    This composes optax's `incremental_update` and `periodic_update`: the
    blended candidate is built first, and `periodic_update` then decides
    whether that candidate or the untouched `old_tensors` is returned.

    Args:
        new_tensors: the most recent ("fast") values of the tensors.
        old_tensors: the slowly-moving copy of the model's parameters.
        steps: number of update steps taken so far on the "online" network.
        update_period: number of steps between refreshes of the slow copy.
        step_size: Polyak averaging factor applied when a refresh happens.

    Returns:
        The slow copy of the parameters. On refresh steps (once every
        `update_period` steps) it is
        `step_size * new_tensors + (1 - step_size) * old_tensors`;
        on all other steps it is `old_tensors` unchanged.
    """
    # The blend is built on every call; `periodic_update` only selects
    # whether it is the value that gets returned.
    blended = incremental_update(new_tensors, old_tensors, step_size)
    return periodic_update(blended, old_tensors, steps, update_period)
[docs]
def delayed_update(
    new_tensors: base.Params,
    old_tensors: base.Params,
    steps: Array,
    start_step: int,
) -> base.Params:
    """Hold parameters fixed until `start_step`, then track `new_tensors`.

    Args:
        new_tensors: the most recent values of the tensors.
        old_tensors: the copy of the parameters that is returned unchanged
            while `steps < start_step`.
        steps: number of update steps taken so far on the "online" network.
        start_step: first step at which `new_tensors` is returned instead.

    Returns:
        `old_tensors` while `steps < start_step`, and `new_tensors` once
        `steps >= start_step`.
    """
    has_started = steps >= start_step

    def _take_new():
        return new_tensors

    def _keep_old():
        return old_tensors

    # `jax.lax.cond` performs the selection, so `steps` does not have to
    # be a plain Python integer.
    return jax.lax.cond(has_started, _take_new, _keep_old)