Source code for memorax.loggers.logger

import atexit
from typing import Protocol, runtime_checkable

from memorax.utils.typing import PyTree


@runtime_checkable
class Logger(Protocol):
    """Structural interface that every logging backend satisfies.

    Any object exposing compatible ``log`` and ``finish`` methods conforms;
    no inheritance is required. Because the protocol is
    ``runtime_checkable``, ``isinstance(obj, Logger)`` can be used, though
    that check only verifies the methods exist, not their signatures.
    """

    def log(self, data: PyTree, step: int, **kwargs) -> None:
        """Record ``data`` at integer ``step``.

        Extra keyword arguments are accepted; their meaning is left to each
        implementation.
        """

    def finish(self) -> None:
        """Signal that no further logging will happen.

        Presumably lets a backend flush or close its outputs; the exact
        effect is implementation-defined.
        """
class MultiLogger:
    """Fan out logging calls to several ``Logger`` backends.

    ``finish`` is registered with :mod:`atexit` at construction, so backends
    are finalized even if the caller never calls it. ``finish`` is
    idempotent: the first call finalizes every backend, and later calls
    (including the one made by ``atexit`` after an explicit ``finish()``) do
    nothing, so no backend is ever finished twice.
    """

    def __init__(self, loggers: list[Logger]):
        """Wrap ``loggers``; calls are dispatched to them in list order."""
        self.loggers = loggers
        # Guards against finishing backends twice (explicit call + atexit).
        self._finished = False
        # NOTE: atexit keeps a strong reference to this bound method, so the
        # instance stays alive until interpreter shutdown.
        atexit.register(self.finish)

    def log(self, data: PyTree, step: int, **kwargs) -> None:
        """Forward ``data``, ``step`` and ``kwargs`` to every backend in order.

        An exception from a backend propagates immediately; later backends
        do not receive this call.
        """
        for logger in self.loggers:
            logger.log(data, step, **kwargs)

    def finish(self) -> None:
        """Finish every backend, at most once per ``MultiLogger``.

        Each backend's ``finish`` is attempted even if an earlier one raises;
        the first ``Exception`` encountered is re-raised after all backends
        have been tried. Subsequent calls are no-ops.
        """
        if self._finished:
            return
        # Mark as finished up front: retrying at exit after a partial failure
        # would re-finish the backends that already succeeded.
        self._finished = True
        first_error = None
        for logger in self.loggers:
            try:
                logger.finish()
            except Exception as err:  # keep finalizing the remaining backends
                if first_error is None:
                    first_error = err
        if first_error is not None:
            raise first_error