Source code for memorax.loggers.neptune
from collections import defaultdict
from dataclasses import field
from typing import Optional
from flax import struct
from neptune_scale import Run
from .logger import BaseLogger, BaseLoggerState, PyTree
@struct.dataclass(frozen=True)
class NeptuneLoggerState(BaseLoggerState):
    """State threaded through the Neptune logger's ``log``/``emit``/``finish``.

    It bundles the open Neptune runs with the metrics that have been logged
    but not yet sent.
    """

    # One Neptune ``Run`` per seed index.
    runs: dict[int, Run]
    # Pending metrics grouped by step; an unseen step starts as an empty dict.
    buffer: defaultdict[int, dict[str, PyTree]] = field(
        default_factory=lambda: defaultdict(dict),
    )
[docs]
@struct.dataclass(frozen=True)
class NeptuneLogger(BaseLogger[NeptuneLoggerState]):
    """Logger backend that writes metrics to one Neptune run per seed.

    Metrics are accumulated per step by ``log`` and only sent to Neptune when
    ``emit`` is called.
    """

    # Combined as "<workspace>/<project>" to form the Neptune project name.
    workspace: Optional[str] = None
    project: Optional[str] = None
    # Forwarded unchanged to ``Run(mode=...)``.
    mode: str = "disabled"

    def init(self, cfg: dict) -> NeptuneLoggerState:
        """Open one run per seed and record ``cfg`` on each of them.

        Args:
            cfg: Configuration mapping. ``cfg["num_seeds"]`` sets how many
                runs are opened; the whole mapping is passed to
                ``Run.log_configs``.

        Returns:
            A fresh state holding the runs keyed by seed index, with an empty
            metric buffer.
        """
        project_name = f"{self.workspace}/{self.project}"
        runs = {}
        for seed in range(cfg["num_seeds"]):
            runs[seed] = Run(project=project_name, mode=self.mode)
        # Configs are written only once every run has been opened.
        for seed in runs:
            runs[seed].log_configs(cfg)
        return NeptuneLoggerState(runs=runs)

    def log(
        self, state: NeptuneLoggerState, data: PyTree, step: int
    ) -> NeptuneLoggerState:
        """Merge ``data`` into the metrics buffered for ``step``.

        Nothing is sent to Neptune here. The buffer is updated in place and
        the same ``state`` object is returned.

        Args:
            state: Logger state whose buffer receives the metrics.
            data: Mapping of metric name to value; keys already buffered for
                this step are overwritten.
            step: Step the metrics belong to.

        Returns:
            The ``state`` that was passed in.
        """
        pending = state.buffer[step]
        pending.update(data)
        return state

    def emit(self, state: NeptuneLoggerState) -> NeptuneLoggerState:
        """Send every buffered step to the runs, then empty the buffer.

        Steps are sent in ascending order. For each run, every metric except
        ``"SPS"`` is indexed with the run's seed before being logged; ``"SPS"``
        is logged as-is (presumably a single value shared by all seeds --
        confirm against the caller).

        Args:
            state: Logger state holding the runs and the buffered metrics.

        Returns:
            The ``state`` that was passed in, with its buffer cleared.
        """
        for step in sorted(state.buffer):
            metrics = state.buffer[step]
            for seed, run in state.runs.items():
                per_seed = {}
                for name, value in metrics.items():
                    per_seed[name] = value if name == "SPS" else value[seed]
                run.log_metrics(per_seed, step=step)
        state.buffer.clear()
        return state

    def finish(self, state: NeptuneLoggerState) -> None:
        """Close every run held by ``state``."""
        for seed in state.runs:
            state.runs[seed].close()