"""TensorBoard logger for memorax (module ``memorax.loggers.tensorboard``)."""

import os

import jax
from tensorboardX import SummaryWriter

from memorax.utils.axes import ensure_axis
from memorax.utils.typing import PyTree


class TensorBoardLogger:
    """Log scalar metrics for one or more seeds to TensorBoard via ``tensorboardX``.

    One ``SummaryWriter`` is created per seed. With a single seed the writer
    logs directly into ``directory``. With several seeds, each writer gets
    its own ``seed_<i>`` subdirectory so the runs remain distinguishable.
    """

    def __init__(self, directory="tensorboard", num_seeds=1, **kwargs):
        """Create one ``SummaryWriter`` per seed.

        Args:
            directory: Root log directory.
            num_seeds: Number of seeds whose metrics will be logged.
            **kwargs: Accepted for interface compatibility; currently ignored.
        """
        if num_seeds == 1:
            # Single run: keep the original layout so existing log paths are unchanged.
            log_dirs = {0: directory}
        else:
            # One subdirectory per seed. Sharing a single directory would merge
            # every seed's scalars into one indistinguishable TensorBoard run.
            log_dirs = {
                seed: os.path.join(directory, f"seed_{seed}")
                for seed in range(num_seeds)
            }
        self.writers = {
            seed: SummaryWriter(log_dir=log_dir) for seed, log_dir in log_dirs.items()
        }

    def log(self, data: PyTree, step: int, **kwargs) -> None:
        """Write every metric in ``data`` as a scalar for each seed.

        Args:
            data: Mapping from metric name to value. ``data.items()`` is used
                below, so this assumes a flat mapping. Each value is passed
                through ``ensure_axis(value, num_seeds)`` and then indexed by
                seed.
            step: Global step recorded with each scalar.
            **kwargs: Accepted for interface compatibility; ignored.
        """
        num_seeds = len(self.writers)
        data = jax.tree.map(lambda v: ensure_axis(v, num_seeds), data)
        for seed, writer in self.writers.items():
            for metric, value in data.items():
                # Convert the jax array element to a plain Python float so the
                # writer always receives an ordinary scalar.
                writer.add_scalar(metric, float(value[seed]), step)

    def finish(self) -> None:
        """Close every per-seed writer."""
        for writer in self.writers.values():
            writer.close()