Source code for memorax.loggers.console
from collections import defaultdict
from dataclasses import field
from typing import DefaultDict
from flax import struct
from .logger import BaseLogger, BaseLoggerState, PyTree
def _new_buffer():
    """Build an empty per-step buffer whose missing steps start as empty dicts."""
    return defaultdict(dict)


@struct.dataclass(frozen=True)
class ConsoleLoggerState(BaseLoggerState):
    """Logger state holding metrics that have been logged but not yet printed."""

    # step -> {metric key: value}. A defaultdict so a step that has not been
    # seen yet can be updated directly without first inserting an empty dict.
    buffer: DefaultDict[int, dict[str, PyTree]] = field(default_factory=_new_buffer)
[docs]
@struct.dataclass(frozen=True)
class ConsoleLogger(BaseLogger[ConsoleLoggerState]):
    """Logger that buffers metrics per step and prints them to stdout.

    ``log`` merges metrics into the state's per-step buffer. ``emit`` prints
    every buffered step in ascending order, then clears the buffer. Each step
    is printed as up to three sections, chosen by key prefix: ``"training"``,
    ``"evaluation"`` and ``"loss"``.
    """

    def init(self, cfg) -> ConsoleLoggerState:
        """Return a fresh state with an empty buffer (``cfg`` is not used)."""
        return ConsoleLoggerState()

    def log(
        self, state: ConsoleLoggerState, data: PyTree, step: int
    ) -> ConsoleLoggerState:
        """Merge ``data`` into the buffered entry for ``step``.

        The buffer is updated in place and the same state object is returned.
        Keys already buffered for ``step`` are overwritten by those in ``data``.
        """
        state.buffer[step].update(data)
        return state

    def _strong_line(self):
        print("###############################################")

    def _weak_line(self):
        print("-----------------------------------------------")

    def _print_section(self, title: str, step: int, metrics: dict[str, float]):
        """Print one titled block of ``name: value`` rows for ``step``.

        ``metrics`` must be non-empty (its longest key sets the column width).
        """
        self._strong_line()
        print(f"{title} - {step:_}")
        self._strong_line()
        # Pad every label (name plus trailing colon) to a common width so the
        # values line up: longest name + 5 columns.
        width = len(max(metrics, key=len)) + 5
        for name, value in metrics.items():
            label = f"{name}:"
            # A single colon only: the label already ends with ":".
            print(f"{label:<{width}} {value:.2f}")
        self._weak_line()

    def emit(self, state: ConsoleLoggerState) -> ConsoleLoggerState:
        """Print every buffered step in ascending step order, then clear the buffer.

        For each step, keys starting with ``"training"``, ``"evaluation"`` and
        ``"loss"`` are printed in that order, each as its own section. Each row
        is labelled by the last ``/``-separated component of its key. Each
        value is reduced with ``.mean().item()``, so it must be array-like
        (presumably a NumPy/JAX array — verify against callers). Keys matching
        none of the prefixes are not printed. Sections with no matching keys
        are skipped.

        Returns the same state object, with its buffer emptied.
        """
        # (key prefix, section title), in print order.
        sections = (
            ("training", "TRAINING"),
            ("evaluation", "EVALUATION"),
            ("loss", "LOSSES"),
        )
        for step, data in sorted(state.buffer.items()):
            # Reduce all sections before printing anything for this step, so a
            # failing value aborts before partial output (as before).
            grouped = [
                (
                    title,
                    {
                        k.split("/")[-1]: v.mean().item()
                        for k, v in data.items()
                        if k.startswith(prefix)
                    },
                )
                for prefix, title in sections
            ]
            for title, metrics in grouped:
                if metrics:
                    self._print_section(title, step, metrics)
        state.buffer.clear()
        return state