Source code for memorax.loggers.checkpoint
import jax
import numpy as np
import orbax.checkpoint.experimental.v1 as ocp
from memorax.utils.decorators import callback
from memorax.utils.typing import PyTree
[docs]
class CheckpointLogger:
    """Logger that writes the training state to disk as Orbax checkpoints.

    Only the ``train_state`` keyword passed to :meth:`log` is persisted; the
    ``data`` payload (metrics) is accepted for interface compatibility but is
    not written anywhere by this logger.
    """

    def __init__(self, directory="checkpoints", max_to_keep=None, **kwargs):
        """Set up the underlying Orbax checkpointer.

        Args:
            directory: Root directory handed to ``ocp.training.Checkpointer``.
            max_to_keep: When not ``None``, wrapped in a ``LatestN``
                preservation policy so only that many most-recent checkpoints
                are retained. When ``None``, no preservation policy is passed
                and Orbax's own default applies.
            **kwargs: Accepted and ignored (presumably so every logger shares
                a common constructor signature — TODO confirm).
        """
        policy = (
            None
            if max_to_keep is None
            else ocp.training.preservation_policies.LatestN(max_to_keep)
        )
        self.checkpointer = ocp.training.Checkpointer(
            directory,
            preservation_policy=policy,
        )

    @callback
    def log(self, data: PyTree, step: int, train_state: PyTree | None = None, **kwargs):
        """Save ``train_state`` as the checkpoint for ``step``.

        Does nothing when ``train_state`` is ``None``. Every leaf is converted
        with ``np.asarray`` before saving, so the checkpoint holds plain NumPy
        arrays. ``@callback`` presumably lets this be called from traced/jitted
        code via a host callback — TODO confirm against
        ``memorax.utils.decorators``.

        Args:
            data: Logged metrics; unused by this logger.
            step: Step number; coerced with ``int`` and used as the checkpoint key.
            train_state: PyTree to persist, or ``None`` to skip saving.
            **kwargs: Ignored.
        """
        if train_state is not None:
            host_state = jax.tree.map(np.asarray, train_state)
            # force=True presumably bypasses Orbax's save-decision policy so
            # every logged step is written — TODO confirm against Orbax docs.
            self.checkpointer.save_pytree(int(step), host_state, force=True)

    def finish(self) -> None:
        """Close the underlying Orbax checkpointer."""
        self.checkpointer.close()