memorax.utils.delayed_update

memorax.utils.delayed_update(new_tensors, old_tensors, new_opt_state, old_opt_state, steps, start_step)

Update all parameters only after a given timestep is reached.

Parameters:
Return type:

Union[Array, ndarray, bool, number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]

Returns:

A copy of the model’s parameters that equals old_tensors before start_step and new_tensors from start_step onward.
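
The selection rule described above can be illustrated with a minimal sketch in JAX. This is not the library's implementation. The helper name select_after_start and the example parameter trees are hypothetical. The sketch also leaves out the optimizer-state arguments (new_opt_state, old_opt_state), because this page does not document how they are handled.

```python
import jax
import jax.numpy as jnp


def select_after_start(new_tree, old_tree, step, start_step):
    """Hypothetical helper: return new_tree once step >= start_step, else old_tree."""
    # Scalar boolean; it broadcasts against every leaf in the trees.
    use_new = jnp.asarray(step) >= start_step
    # Both trees must share the same structure and leaf shapes.
    return jax.tree_util.tree_map(
        lambda new, old: jnp.where(use_new, new, old), new_tree, old_tree
    )


# Illustrative parameter trees (assumed, not taken from the library).
old_params = {"w": jnp.zeros(3), "b": jnp.zeros(())}
new_params = {"w": jnp.ones(3), "b": jnp.ones(())}

# Before start_step the old values are kept.
before = select_after_start(new_params, old_params, step=4, start_step=10)
# From start_step onward the new values are used.
after = select_after_start(new_params, old_params, step=10, start_step=10)
```

Using jnp.where rather than a Python if keeps the selection traceable, so a helper like this could be called inside a jitted training step where the step counter is a traced array.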