Source code for memorax.loggers.logger
import atexit
from typing import Protocol, runtime_checkable
from memorax.utils.typing import PyTree
[docs]
@runtime_checkable
class Logger(Protocol):
    """Structural interface that every experiment logger satisfies.

    Any object providing ``log`` and ``finish`` methods with compatible
    signatures conforms to this protocol; no subclassing is required.
    Because the protocol is ``runtime_checkable``, ``isinstance(obj, Logger)``
    works, but it only verifies that those methods exist, not their
    signatures.
    """

    def log(self, data: PyTree, step: int, **kwargs) -> None:
        """Record ``data`` for the integer ``step``.

        Extra keyword arguments are implementation-specific options.
        """

    def finish(self) -> None:
        """Finalize the logger once no further data will be logged."""
[docs]
class MultiLogger:
    """Fan logging calls out to several loggers.

    Each call to :meth:`log` is forwarded, in list order, to every wrapped
    logger. :meth:`finish` is registered with :mod:`atexit`, so the wrapped
    loggers are finalized at interpreter exit even if the caller never calls
    it. Finalization happens at most once.
    """

    def __init__(self, loggers: list[Logger]):
        """Wrap ``loggers`` and register :meth:`finish` to run at exit.

        Args:
            loggers: Loggers to forward to, in call order. The list object is
                stored as-is (not copied).
        """
        self.loggers = loggers
        # Prevents finalizing the wrapped loggers twice, e.g. an explicit
        # ``finish()`` followed by the atexit hook registered below.
        self._finished = False
        atexit.register(self.finish)

    def log(self, data: PyTree, step: int, **kwargs) -> None:
        """Forward ``data`` for ``step`` to every wrapped logger.

        Extra keyword arguments are passed through unchanged. An exception
        raised by one logger propagates immediately, so loggers later in the
        list do not receive this call.
        """
        for logger in self.loggers:
            logger.log(data, step, **kwargs)

    def finish(self) -> None:
        """Finalize every wrapped logger, at most once.

        Subsequent calls, including the one made by the atexit hook after an
        explicit ``finish()``, are no-ops. Every logger's ``finish`` is
        attempted even if an earlier one raises.

        Raises:
            Exception: The first exception raised by a wrapped logger's
                ``finish``, re-raised after all loggers have been attempted.
        """
        if self._finished:
            return
        # Mark first so a failing logger cannot cause a retry at exit.
        # The atexit hook is intentionally not unregistered here: a callback
        # that unregisters itself and then raises can crash some CPython
        # versions (bpo-46025).
        self._finished = True
        errors = []
        for logger in self.loggers:
            try:
                logger.finish()
            except Exception as err:
                # Keep going so the remaining loggers are still finalized.
                errors.append(err)
        if errors:
            raise errors[0]