Source code for memorax.loggers.wandb

from collections import defaultdict
from dataclasses import field
from typing import Literal, Optional

import chex
from wandb.sdk.wandb_run import Run

import wandb
from memorax.utils.stats import naniqm

from .logger import BaseLogger, BaseLoggerState, PyTree


@chex.dataclass(frozen=True)
class WandbLoggerState(BaseLoggerState):
    """State for ``WandbLogger``: the W&B runs plus metrics not yet emitted.

    The dataclass itself is frozen, but ``buffer`` is a mutable ``defaultdict``
    that ``WandbLogger.log`` fills in place and ``WandbLogger.emit`` clears in
    place, so the same state object is reused across calls.
    """
    # One W&B run per seed index; ``WandbLogger.init`` builds this mapping with
    # keys ``0..num_seeds-1``.
    runs: dict[int, Run]
    # step -> {metric name: value}. ``default_factory`` gives each state its own
    # buffer, and ``defaultdict(dict)`` lets ``log`` call
    # ``buffer[step].update(...)`` without first creating the step entry.
    buffer: defaultdict[int, dict[str, PyTree]] = field(
        default_factory=lambda: defaultdict(dict)
    )


def _slice_for_seed(value, seed: int):
    """Return the part of ``value`` that belongs to run ``seed``.

    Values with ``ndim >= 1`` are indexed on their leading axis as
    ``value[seed]``. Everything else is returned unchanged and so is shared by
    every run. That covers Python scalars, 0-d arrays, and objects with no
    ``ndim`` at all (e.g. strings or W&B media objects).
    """
    # ``getattr`` with a default of 0 treats objects lacking ``ndim`` as scalars
    # instead of raising AttributeError in the middle of ``emit``.
    if getattr(value, "ndim", 0) == 0:
        return value
    return value[seed]


@chex.dataclass(frozen=True)
class WandbLogger(BaseLogger[WandbLoggerState]):
    """Logger that writes metrics to one Weights & Biases run per seed.

    ``log`` only buffers data by step; nothing is sent to W&B until ``emit`` is
    called. When emitting, array-valued metrics are split across runs by
    indexing their leading axis with the seed index. Scalar metrics are sent
    unchanged to every run.
    """

    entity: Optional[str] = None
    project: Optional[str] = None
    name: Optional[str] = None
    group: Optional[str] = None
    mode: Literal["online", "offline", "disabled", "shared"] = "disabled"
    num_seeds: int = 1

    def init(self, **kwargs) -> WandbLoggerState:
        """Create ``num_seeds`` W&B runs and return a fresh logger state.

        Args:
            **kwargs: Only ``cfg`` is read. It is a mapping used as the base run
                config; a missing or ``None`` value is treated as empty.

        Returns:
            A ``WandbLoggerState`` whose ``runs`` maps seed index
            ``0..num_seeds-1`` to its W&B run.
        """
        cfg = kwargs.get("cfg")
        if cfg is None:
            cfg = {}
        runs = {}
        for seed in range(self.num_seeds):
            # Each run's config records ``cfg["seed"] + seed`` when a base seed
            # is configured, otherwise just the seed index.
            run_seed = cfg["seed"] + seed if "seed" in cfg else seed
            runs[seed] = wandb.init(
                entity=self.entity,
                project=self.project,
                name=self.name,
                group=self.group,
                mode=self.mode,
                config={**cfg, "seed": run_seed},
                # Start a new run each time instead of reusing/finishing the
                # previously created one, so all seeds' runs coexist.
                reinit="create_new",
            )
        return WandbLoggerState(runs=runs)

    def log(
        self, state: WandbLoggerState, data: PyTree, step: int
    ) -> WandbLoggerState:
        """Buffer ``data`` under ``step`` without sending it to W&B.

        ``data`` is merged into any data already buffered for the same step.
        Later values overwrite earlier ones with the same key. The buffer is
        mutated in place and the same ``state`` is returned.
        """
        state.buffer[step].update(data)
        return state

    def emit(self, state: WandbLoggerState) -> WandbLoggerState:
        """Send all buffered data to every run, then clear the buffer.

        Steps are sent in ascending order so that each run receives
        non-decreasing step values, which W&B expects per run.
        """
        for step, data in sorted(state.buffer.items()):
            for seed, run in state.runs.items():
                run.log(
                    {key: _slice_for_seed(value, seed) for key, value in data.items()},
                    step=step,
                )
        state.buffer.clear()
        return state

    def finish(self, state: WandbLoggerState) -> None:
        """Finish every W&B run held in ``state``.

        Data still sitting in the buffer is not emitted here; call ``emit``
        first if it should be sent.
        """
        for run in state.runs.values():
            run.finish()