Source code for memorax.loggers.file

from collections import defaultdict
from dataclasses import field
from datetime import datetime
from pathlib import Path
from typing import DefaultDict

from flax import struct

from .logger import BaseLogger, BaseLoggerState, PyTree


@struct.dataclass(frozen=True)
class FileLoggerState(BaseLoggerState):
    base: Path
    paths: dict[int, Path]
    buffer: DefaultDict[int, dict[str, PyTree]] = field(
        default_factory=lambda: defaultdict(dict)
    )


[docs] @struct.dataclass(frozen=True) class FileLogger(BaseLogger[FileLoggerState]): algorithm: str environment: str directory: str = "logs"
[docs] def init(self, cfg: dict) -> FileLoggerState: if "actor" in cfg["algorithm"]: cell = cfg["algorithm"]["actor"]["torso"]["_target_"] if "RNN" in cell: cell = cfg["algorithm"]["actor"]["torso"]["cell"]["_target_"] else: cell = cfg["algorithm"]["torso"]["_target_"] cell = cell.split(".")[-1] if "parameters" in cfg["environment"]: params = "" for key, param in cfg["environment"]["parameters"].items(): if key == "max_steps_in_episode": continue params += f"{param}/" if params: params = params[:-1] base_path = ( Path(self.directory) / self.environment / params / self.algorithm / cell / f"{datetime.now():%Y%m%d-%H%M%S}" ) else: base_path = ( Path(self.directory) / self.environment / self.algorithm / cell / f"{datetime.now():%Y%m%d-%H%M%S}" ) base_path.mkdir(exist_ok=True, parents=True) paths = {seed: (base_path / str(seed)) for seed in range(cfg["num_seeds"])} for _, path in paths.items(): path.mkdir(exist_ok=True, parents=True) return FileLoggerState(base=base_path, paths=paths)
[docs] def log(self, state: FileLoggerState, data: PyTree, step: int) -> FileLoggerState: state.buffer[step].update(data) return state
[docs] def emit(self, state: FileLoggerState) -> FileLoggerState: for step, data in sorted(state.buffer.items()): for seed, path in state.paths.items(): for metric, value in { k: ( v[seed] if not ( isinstance(v, int) or isinstance(v, float) or v.ndim == 0 ) else v ) for k, v in data.items() }.items(): metric_path = (path / f"{metric}.csv").resolve() metric_path.parent.mkdir(exist_ok=True, parents=True) if not metric_path.exists(): with metric_path.open("a") as f: f.write(f"step,{metric}\n") with metric_path.open("a") as f: f.write(f"{step},{value}\n") state.buffer.clear() return state
[docs] def finish(self, state: FileLoggerState) -> None: pass