Source code for memorax.loggers.tensorboard

import os
from collections import defaultdict
from dataclasses import field

from flax import struct
from tensorboardX import SummaryWriter

from .logger import BaseLogger, BaseLoggerState, PyTree


@struct.dataclass(frozen=True)
class TensorBoardLoggerState(BaseLoggerState):
    """State threaded through the TensorBoard logger's methods.

    Holds the open ``SummaryWriter`` objects and the metrics that have been
    logged but not yet written to them.
    """

    # Open writers, keyed by an integer index (the seed index in this module).
    writers: dict[int, SummaryWriter]
    # Pending metrics: step -> {metric name -> value}. Being a
    # ``defaultdict(dict)``, a step's entry is created on first access. The
    # container itself is mutated in place even though the dataclass is frozen.
    buffer: defaultdict[int, dict[str, PyTree]] = field(
        default_factory=lambda: defaultdict(dict)
    )


@struct.dataclass(frozen=True)
class TensorBoardLogger(BaseLogger[TensorBoardLoggerState]):
    """Logger that buffers metrics and writes them as TensorBoard scalars.

    Metrics are accumulated per step by :meth:`log` and written out for every
    seed by :meth:`emit`. Each seed owns a separate ``SummaryWriter``.

    Attributes:
        log_dir: Root directory for the event files. Every seed writes into
            its own ``seed_<index>`` subdirectory beneath it.
    """

    log_dir: str = "tensorboard"

    def init(self, cfg: dict) -> TensorBoardLoggerState:
        """Create one writer per seed and return the initial logger state.

        Args:
            cfg: Run configuration; ``cfg["num_seeds"]`` gives the number of
                writers to create.

        Returns:
            A state holding the writers (keyed by seed index) and an empty
            buffer.

        Raises:
            KeyError: If ``cfg`` has no ``"num_seeds"`` entry.
        """
        # Each seed needs its own directory: TensorBoard treats a directory as
        # a single run, so several writers emitting the same tags into one
        # directory would be merged (or partly dropped) instead of showing up
        # as separate per-seed curves.
        writers = {
            seed: SummaryWriter(
                log_dir=os.path.join(self.log_dir, f"seed_{seed}")
            )
            for seed in range(cfg["num_seeds"])
        }
        return TensorBoardLoggerState(writers=writers)

    def log(
        self, state: TensorBoardLoggerState, data: PyTree, step: int
    ) -> TensorBoardLoggerState:
        """Buffer ``data`` under ``step`` without writing anything yet.

        Args:
            state: Current logger state; its buffer is updated in place.
            data: Mapping of metric name to value. Entries are merged into
                whatever is already buffered for ``step``, overwriting
                metrics with the same name.
            step: Step the metrics belong to.

        Returns:
            The same ``state`` object.
        """
        state.buffer[step].update(data)
        return state

    def emit(self, state: TensorBoardLoggerState) -> TensorBoardLoggerState:
        """Write all buffered metrics to every seed's writer, then clear them.

        Steps are written in ascending order. Each metric value is indexed by
        the seed index to pick that seed's scalar, except ``"SPS"``, which is
        written unchanged to every seed.

        Args:
            state: Current logger state; its buffer is emptied in place.

        Returns:
            The same ``state`` object.
        """
        # Buffer keys are unique, so ordering by key alone is sufficient.
        for step in sorted(state.buffer):
            metrics = state.buffer[step]
            for seed, writer in state.writers.items():
                for metric, value in metrics.items():
                    writer.add_scalar(
                        metric,
                        value[seed] if metric != "SPS" else value,
                        step,
                    )
        state.buffer.clear()
        return state

    def finish(self, state: TensorBoardLoggerState) -> None:
        """Close every writer held by ``state``.

        Args:
            state: Logger state whose writers should be closed.
        """
        for writer in state.writers.values():
            writer.close()