memorax.utils.delayed_update
- memorax.utils.delayed_update(new_tensors, old_tensors, new_opt_state, old_opt_state, steps, start_step)
Update all parameters only after a given timestep is reached.
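The following is a minimal sketch of the selection rule described here. It assumes the rule is applied leaf by leaf to both the parameter tree and the optimizer-state tree, and that the result comes back as a `(tensors, opt_state)` pair. `delayed_update_sketch` is a hypothetical name. It is not memorax's implementation.

```python
import jax
import jax.numpy as jnp


def delayed_update_sketch(new_tensors, old_tensors, new_opt_state, old_opt_state,
                          steps, start_step):
    """Hypothetical re-implementation of the documented behaviour (not memorax's source)."""
    # Scalar boolean: True once the step counter has reached start_step.
    take_new = steps >= start_step
    # Pick the new or old leaf; jnp.where keeps this traceable under jax.jit.
    select = lambda new, old: jnp.where(take_new, new, old)
    tensors = jax.tree_util.tree_map(select, new_tensors, old_tensors)
    # Assumption: the optimizer state follows the same delay as the parameters.
    opt_state = jax.tree_util.tree_map(select, new_opt_state, old_opt_state)
    return tensors, opt_state
```

The old and new trees must share the same structure, because `jax.tree_util.tree_map` walks them in parallel.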
- Parameters:
  - new_tensors (Union[Array, ndarray, bool, number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – the latest value of the tensors.
  - old_tensors (Union[Array, ndarray, bool, number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – a copy of the model’s parameters that remains unchanged until steps >= start_step.
  - steps (Array) – current number of update steps on the “online” network.
  - start_step (int) – timestep at which the copy begins mirroring new_tensors.
  - new_opt_state (Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree])
  - old_opt_state (Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree])
- Return type:
  Union[Array, ndarray, bool, number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]
- Returns:
  a copy of the model’s parameters that equals old_tensors before start_step and new_tensors from start_step onward.
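`steps` is an `Array`, so it may be a traced value inside `jax.jit`. In that case a Python `if steps >= start_step:` would raise an error. The hypothetical `delayed_select` below sketches one jit-safe way to produce the documented return value: it swaps whole pytrees with `jax.lax.cond`. It is illustrative only and makes no claim about how memorax implements this function.

```python
import jax
import jax.numpy as jnp


@jax.jit
def delayed_select(new_tree, old_tree, steps, start_step):
    # Hypothetical helper: `steps` is traced under jit, so the branch is chosen
    # with lax.cond rather than a Python `if`.
    return jax.lax.cond(
        steps >= start_step,
        lambda trees: trees[0],  # from start_step onward: mirror the new values
        lambda trees: trees[1],  # before start_step: keep the old copy
        (new_tree, old_tree),
    )


old = {"w": jnp.zeros(2)}
new = {"w": jnp.ones(2)}
for step in range(4):
    # Steps 0 and 1 return the zeros; steps 2 and 3 return the ones.
    print(step, delayed_select(new, old, jnp.asarray(step), 2)["w"])
```

`jax.lax.cond` requires both branches to return trees with identical structure, shapes, and dtypes. That holds when old_tensors is a copy of new_tensors.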