from collections import defaultdict
from typing import Any
import jax
import jax.numpy as jnp
from rich import box
from rich.console import Console
from rich.live import Live
from rich.progress import (
BarColumn,
Progress,
SpinnerColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)
from rich.table import Table
from memorax.utils.typing import PyTree
class DashboardLogger:
    """Live terminal dashboard for training metrics, rendered with ``rich``.

    The dashboard is one rounded table containing, top to bottom:

    * a two-column summary panel: the static ``summary`` entries given at
      construction, any ``summary`` entries found in the logged data, and
      the current step;
    * one metrics table per group, laid out two groups per row, where a
      group is the first ``/``-separated component of each flattened key;
    * a progress bar tracking ``step`` against ``total_timesteps``.

    The ``rich.live.Live`` display is started in ``__init__``. Call
    :meth:`finish` (or use the logger as a context manager) to stop it and
    restore the terminal cursor.
    """

    def __init__(
        self, total_timesteps=0, refresh_per_second=10, summary=None, **kwargs
    ):
        """Create the dashboard and start the live display immediately.

        Args:
            total_timesteps: Used as the progress bar's ``total``.
            refresh_per_second: Refresh rate passed to ``rich.live.Live``.
            summary: Optional mapping of ``label -> value`` pairs shown in the
                summary panel on every render.
            **kwargs: Accepted and ignored (presumably so every logger can be
                built from the same keyword set -- TODO confirm with callers).
        """
        self.summary = summary or {}
        self.console = Console()
        self.progress = Progress(
            TextColumn("[progress.description]{task.description}"),
            SpinnerColumn(),
            TimeElapsedColumn(),
            BarColumn(bar_width=None),
            TimeRemainingColumn(),
            expand=True,
            console=self.console,
        )
        self.progress_task = self.progress.add_task("Progress", total=total_timesteps)
        dashboard = self.build_dashboard({}, 0, self.progress, self.progress_task)
        self.live = Live(
            dashboard,
            console=self.console,
            refresh_per_second=refresh_per_second,
            transient=False,
        )
        self.live.start()

    def __enter__(self) -> "DashboardLogger":
        """Return the logger; the live display is already running."""
        return self

    def __exit__(self, *exc_info) -> None:
        """Stop the display even when the ``with`` body raised."""
        self.finish()

    def log(self, data: PyTree, step: int, **kwargs) -> None:
        """Re-render the dashboard with ``data`` at ``step``.

        Args:
            data: Metrics pytree; see :meth:`build_dashboard` for the layout.
            step: Current step; advances the progress bar.
            **kwargs: Accepted and ignored.
        """
        self.progress.update(self.progress_task, completed=int(step))
        dashboard = self.build_dashboard(data, step, self.progress, self.progress_task)
        self.live.update(dashboard, refresh=True)

    def finish(self) -> None:
        """Stop the live display and make the terminal cursor visible again."""
        self.live.stop()
        self.console.show_cursor(True)

    def group(self, data: dict[str, PyTree]) -> dict[str, dict[str, Any]]:
        """Flatten ``data`` and bucket its leaves by top-level key.

        Every leaf gets a ``/``-joined key built from its pytree path. Keys
        containing ``/`` are split once: the first component names the group
        and the remainder is the metric name. Keys without ``/`` land in the
        ``""`` group.

        Args:
            data: Pytree of metrics (dicts, lists, tuples, namedtuples, ...).

        Returns:
            Mapping ``group -> {metric name -> leaf}``.
        """

        def entry_name(entry: Any) -> str:
            # Key-path entry types expose their label differently: DictKey and
            # FlattenedIndexKey use ``.key``, GetAttrKey ``.name`` and
            # SequenceKey ``.idx``. Reading only ``.key`` crashed on any list,
            # tuple or namedtuple in the logged data.
            for attr in ("key", "name", "idx"):
                if hasattr(entry, attr):
                    return str(getattr(entry, attr))
            return str(entry)

        flat = {
            "/".join(entry_name(p) for p in path): leaf
            for path, leaf in jax.tree_util.tree_leaves_with_path(data)
        }
        groups = defaultdict(dict)
        for key, value in flat.items():
            prefix, sep, name = key.partition("/")
            if sep:
                groups[prefix][name] = value
            else:
                groups[""][key] = value
        return dict(groups)

    def build_table(self, heading: str, metrics: dict[str, PyTree]) -> Table:
        """Build a two-column ``name | mean ± std`` table for one metric group.

        Each value is reduced with ``jnp.mean``/``jnp.std``. Magnitudes that
        are nonzero and below 1e-3, or at least 1e4, use scientific notation;
        everything else uses three decimals. The ``± std`` part is omitted
        when the standard deviation is exactly zero.

        Args:
            heading: Title of the name column (the group name).
            metrics: Mapping of metric name to scalar or array value.

        Returns:
            The populated ``rich`` table.
        """
        table = Table(box=None, expand=True)
        table.add_column(heading, justify="left", width=20, style="yellow")
        table.add_column("Value", justify="right", width=10, style="green")
        for name, value in metrics.items():
            # Pull both statistics to host floats once instead of comparing
            # and formatting device arrays repeatedly.
            mean = float(jnp.mean(value))
            std = float(jnp.std(value))
            magnitude = abs(mean)
            fmt = ".3e" if (0 < magnitude < 0.001 or magnitude >= 10000) else ".3f"
            value_str = f"{mean:{fmt}} ± {std:{fmt}}" if std != 0 else f"{mean:{fmt}}"
            table.add_row(name, value_str)
        return table

    @staticmethod
    def _build_summary_row(items: list[tuple[Any, Any]]) -> Table:
        """Lay out ``(label, value)`` pairs alternately in two side-by-side columns."""
        columns = []
        for _ in range(2):
            column = Table(box=None, expand=True)
            column.add_column("Summary", justify="left", width=16, style="white")
            column.add_column("Value", justify="right", width=8, style="white")
            columns.append(column)
        for i, (key, value) in enumerate(items):
            # Python ints get "_" thousands separators; everything else uses str().
            value_str = f"{value:_}" if isinstance(value, int) else f"{value}"
            columns[i % 2].add_row(key, value_str, style="white")
        row = Table(box=None, expand=True, pad_edge=False)
        row.add_row(*columns)
        return row

    def build_dashboard(
        self, data: dict[str, PyTree], step: int, progress: Progress, task: Any
    ) -> Table:
        """Assemble the full dashboard renderable.

        Args:
            data: Metrics pytree. Entries under ``summary`` -- either flat
                ``"summary/<name>"`` keys or a nested ``{"summary": {...}}``
                mapping -- go to the summary panel. Every other group gets its
                own metrics table.
            step: Current step. It is shown in the summary panel (only when
                ``data`` is non-empty) and used as the progress bar's
                completed count.
            progress: ``rich`` progress instance rendered at the bottom.
            task: Task id within ``progress`` that is advanced to ``step``.

        Returns:
            A ``rich`` table suitable for ``Live.update``.
        """
        dashboard = Table(
            box=box.ROUNDED,
            expand=True,
            show_header=False,
            border_style="white",
        )

        groups = self.group(data)
        # Take summary entries from the grouped data so that flat
        # "summary/x" keys and nested {"summary": {"x": ...}} both feed the
        # summary panel, rather than nested ones being popped and lost.
        dynamic_summary = groups.pop("summary", {})
        items = [*self.summary.items(), *dynamic_summary.items()]
        if data:
            items.append(("Step", f"{int(step):_}"))
        dashboard.add_row(self._build_summary_row(items))

        # Two metric groups per dashboard row.
        group_names = list(groups)
        for start in range(0, len(group_names), 2):
            tables = [
                self.build_table(name, groups[name])
                for name in group_names[start : start + 2]
            ]
            row = Table(box=None, expand=True, pad_edge=False)
            row.add_row(*tables)
            dashboard.add_row(row)

        dashboard.add_row("")
        progress.update(task, completed=int(step))
        dashboard.add_row(progress)
        return dashboard