Source code for memorax.loggers.wandb

import jax
import wandb

from memorax.utils.axes import ensure_axis
from memorax.utils.typing import PyTree


[docs] class WandbLogger:
[docs] def __init__( self, entity=None, project=None, name=None, group=None, mode="disabled", cfg=None, seed=0, num_seeds=1, **kwargs, ): self.runs = { i: wandb.init( entity=entity, project=project, name=name, group=group, mode=mode, config={**(cfg or {}), "seed": seed + i}, reinit="create_new", ) for i in range(num_seeds) }
[docs] def log(self, data: PyTree, step: int, **kwargs) -> None: num_seeds = len(self.runs) data = { "/".join(str(p.key) for p in path): ensure_axis(leaf, num_seeds) for path, leaf in jax.tree_util.tree_leaves_with_path(data) } for seed, run in self.runs.items(): run.log({k: v[seed] for k, v in data.items()}, step=step)
[docs] def finish(self) -> None: for run in self.runs.values(): run.finish()