Skip to content

Instantly share code, notes, and snippets.

@leifdenby
Created February 23, 2026 19:35
Show Gist options
  • Select an option

  • Save leifdenby/613d81343c3a6243f65f01bbc0254c43 to your computer and use it in GitHub Desktop.

Select an option

Save leifdenby/613d81343c3a6243f65f01bbc0254c43 to your computer and use it in GitHub Desktop.
visualize anemoi checkpoints
#!/usr/bin/env -S uv run --script
# Usage:
# uv run ./script
# uv run ./script living-shiner.ckpt --modules-only
# uv run ./script --max-depth 6 --max-nodes 800 --out-dir checkpoint_viz_deep
#
# Outputs:
# - <out-dir>/module_tree.txt (only when the loaded root object is an nn.Module)
# - <out-dir>/object_graph.dot
# - <out-dir>/object_graph.svg (if graphviz `dot` is installed)
#
# /// script
# requires-python = ">=3.11,<3.14"
# dependencies = [
# "torch",
# "torch-geometric==2.4.0",
# "anemoi-models @ git+https://github.com/ecmwf/anemoi-core.git@models-0.5.0#subdirectory=models",
# ]
# ///
import argparse
import json
import shutil
import subprocess
import sys
import zipfile
from collections import deque
from itertools import islice
from pathlib import Path
from typing import Any

import torch
from torch import nn
def install_torch_geometric_compat_aliases() -> None:
    """Alias the legacy torch_geometric inspector module path when possible.

    Registers ``torch_geometric.inspector`` in ``sys.modules`` under the older
    dotted name ``torch_geometric.nn.conv.utils.inspector`` so that imports of
    the legacy path resolve to the same module object. Nothing is changed when
    the legacy name is already registered, or when
    ``torch_geometric.inspector`` cannot be imported.
    """
    # NOTE(review): presumably needed so checkpoints pickled against an older
    # torch_geometric layout can still be deserialized -- confirm.
    legacy_name = "torch_geometric.nn.conv.utils.inspector"
    if legacy_name not in sys.modules:
        try:
            from torch_geometric import inspector as replacement
        except Exception:
            # Best effort only: without the newer module there is nothing to alias.
            return
        else:
            sys.modules[legacy_name] = replacement
def load_checkpoint(path: Path, map_location: str, weights_only: bool) -> Any:
    """Load a checkpoint file with ``torch.load`` and return the result.

    The torch_geometric compatibility aliases are installed first so they are
    in place while the checkpoint is being deserialized.

    Args:
        path: Location of the checkpoint file.
        map_location: Forwarded to ``torch.load`` as ``map_location``.
        weights_only: Forwarded to ``torch.load`` as ``weights_only``.

    Returns:
        Whatever object ``torch.load`` produces for the file.
    """
    install_torch_geometric_compat_aliases()
    # NOTE(review): with weights_only=False, torch.load unpickles arbitrary
    # objects; only use that mode on checkpoints from a trusted source.
    load_options = {"map_location": map_location, "weights_only": weights_only}
    return torch.load(path, **load_options)
def read_module_versions_from_checkpoint(path: Path) -> dict[str, str]:
    """Read the recorded package versions from a checkpoint's embedded metadata.

    The checkpoint is inspected as a zip archive. The first member whose name
    ends with ``anemoi-metadata/ai-models.json`` is parsed as JSON and the
    mapping stored under ``provenance_training -> module_versions`` is returned.

    This lookup is purely informational, so every "metadata unavailable"
    situation yields an empty dict instead of an exception: the file is not a
    zip archive, the metadata member is missing, the member is not valid JSON,
    or any level of the expected structure is absent or not a JSON object
    (for example ``"provenance_training": null``).

    Args:
        path: Location of the checkpoint file.

    Returns:
        The ``module_versions`` mapping, or ``{}`` when it cannot be found.
    """
    if not zipfile.is_zipfile(path):
        return {}

    with zipfile.ZipFile(path) as archive:
        metadata_name = next(
            (
                name
                for name in archive.namelist()
                if name.endswith("anemoi-metadata/ai-models.json")
            ),
            None,
        )
        if metadata_name is None:
            return {}
        raw = archive.read(metadata_name)

    try:
        data = json.loads(raw)
    except ValueError:
        # Covers json.JSONDecodeError and UnicodeDecodeError (both ValueError).
        return {}

    # Walk the nested structure defensively: a null or non-object value at any
    # level would otherwise raise AttributeError on the chained .get() calls.
    if not isinstance(data, dict):
        return {}
    provenance = data.get("provenance_training")
    if not isinstance(provenance, dict):
        return {}
    versions = provenance.get("module_versions")
    return versions if isinstance(versions, dict) else {}
def print_package_versions(module_versions: dict[str, str]) -> None:
    """Print the anemoi/torch/lightning entries of a package-version mapping.

    An empty mapping produces a single "not found" line. Otherwise a header
    with the total entry count is printed, followed by one ``name==version``
    line (in sorted name order) for each package whose name starts with one of
    the prefixes of interest, or a notice when no package matches.

    Args:
        module_versions: Mapping of package name to version string.
    """
    if not module_versions:
        print("Package versions: not found in checkpoint metadata")
        return

    print(f"Package versions from checkpoint metadata ({len(module_versions)} total):")
    interesting = ("anemoi.", "torch", "pytorch_lightning", "lightning_")
    selected = [pkg for pkg in sorted(module_versions) if pkg.startswith(interesting)]
    for pkg in selected:
        print(f" {pkg}=={module_versions[pkg]}")
    if not selected:
        print(" (No anemoi/torch/lightning packages found in module_versions)")
def module_tree_lines(module: nn.Module, max_depth: int) -> list[str]:
    """Render a module hierarchy as indented text lines.

    Each line has the form ``<indent><name>: <class module>.<class name>``,
    with one space of indent per nesting level. The top-level module is named
    ``root``; children appear directly below their parent in
    ``named_children()`` order. Children of a module at depth ``max_depth`` or
    deeper are not listed.

    Args:
        module: Top-level module to describe.
        max_depth: Deepest level whose children are still expanded.

    Returns:
        The rendered lines in pre-order.
    """
    rendered: list[str] = []
    # Explicit stack of (module, display name, depth); popping from the end
    # gives a pre-order traversal without recursion.
    pending: list[tuple[nn.Module, str, int]] = [(module, "root", 0)]
    while pending:
        current, label, level = pending.pop()
        kind = current.__class__
        rendered.append(f"{' ' * level}{label}: {kind.__module__}.{kind.__name__}")
        if level < max_depth:
            children = list(current.named_children())
            # Push in reverse so the first child is popped (and rendered) first.
            pending.extend(
                (child, child_name, level + 1)
                for child_name, child in reversed(children)
            )
    return rendered
def short_label(obj: Any) -> str:
    """Build a compact node label for an object.

    The label always starts with the object's fully qualified class name.
    A second label line is appended for some kinds of object:
    ``(nn.Module)`` for modules, ``shape=(...)`` for tensors and ``len=N`` for
    dict, list, tuple and set instances.

    Args:
        obj: Object to describe.

    Returns:
        The label text, with the two parts joined by a DOT line-break escape.
    """
    kind = obj.__class__
    qualified = f"{kind.__module__}.{kind.__name__}"

    if isinstance(obj, nn.Module):
        detail = "(nn.Module)"
    elif isinstance(obj, torch.Tensor):
        detail = f"shape={tuple(obj.shape)}"
    elif isinstance(obj, (dict, list, tuple, set)):
        detail = f"len={len(obj)}"
    else:
        return qualified

    # "\\n" is a literal backslash followed by "n" (not a newline character),
    # which Graphviz renders as a line break inside the label.
    return f"{qualified}\\n{detail}"
def node_fillcolor(obj: Any) -> str | None:
module_name = obj.__class__.__module__
if module_name.startswith("anemoi.models"):
return "#b7e4c7" # light green
if module_name.startswith("torch.nn"):
return "#dbeafe" # light blue
return None
def iter_edges(obj: Any, modules_only: bool, *, max_items: int = 20) -> list[tuple[str, Any]]:
    """List the labelled outgoing references of an object.

    The kind of ``obj`` decides which references are reported:

    * ``nn.Module``: one ``module.<name>`` edge per named child. Unless
      ``modules_only`` is set, this is followed by an ``attr.<key>`` edge for
      each instance attribute that is either public or starts with ``_graph``
      and whose value is not a plain ``str``/``int``/``float``/``bool``/``None``.
    * ``dict``: ``key[<repr of key, cut to 24 chars>]`` edges.
    * ``list`` / ``tuple``: ``[<index>]`` edges.
    * ``set`` / ``frozenset``: ``set[<position>]`` edges.
    * any other object with a ``__dict__``: ``attr.<key>`` edges for public,
      non-primitive attributes.

    Args:
        obj: Object to inspect.
        modules_only: When true, only module-to-child-module edges are
            returned; every non-module object yields no edges.
        max_items: Maximum number of entries followed for dict, list, tuple
            and set containers (module children and attributes are not
            limited). Non-positive values follow no entries.

    Returns:
        ``(label, target)`` pairs in iteration order.
    """
    limit = max(max_items, 0)
    primitives = (str, int, float, bool, type(None))

    if isinstance(obj, nn.Module):
        edges = [(f"module.{name}", child) for name, child in obj.named_children()]
        if modules_only:
            return edges
        for key, value in vars(obj).items():
            # Skip private state (including _modules/_parameters/_buffers) but
            # keep attributes whose name starts with "_graph".
            if key.startswith("_") and not key.startswith("_graph"):
                continue
            if isinstance(value, primitives):
                continue
            edges.append((f"attr.{key}", value))
        return edges

    if modules_only:
        return []

    if isinstance(obj, dict):
        return [(f"key[{repr(k)[:24]}]", v) for k, v in islice(obj.items(), limit)]
    if isinstance(obj, (list, tuple)):
        return [(f"[{i}]", item) for i, item in enumerate(obj[:limit])]
    if isinstance(obj, (set, frozenset)):
        # islice takes a prefix without copying the whole set into a list.
        return [(f"set[{i}]", item) for i, item in enumerate(islice(obj, limit))]
    if hasattr(obj, "__dict__"):
        return [
            (f"attr.{key}", value)
            for key, value in vars(obj).items()
            if not key.startswith("_") and not isinstance(value, primitives)
        ]
    return []
def write_object_graph_dot(
    root: Any,
    output_dot: Path,
    max_nodes: int,
    max_depth: int,
    modules_only: bool,
) -> tuple[int, int]:
    """Walk an object graph breadth-first and write it as a Graphviz DOT file.

    Objects are identified by ``id()``, so an object reachable along several
    paths becomes a single node with several incoming edges. Exploration stops
    once ``max_nodes`` nodes have been registered; nodes still waiting in the
    queue at that point are kept but not expanded. Nodes at ``max_depth`` are
    never expanded.

    Args:
        root: Object the traversal starts from.
        output_dot: File the DOT text is written to (UTF-8, overwritten).
        max_nodes: Upper bound on the number of nodes in the graph.
        max_depth: Distance from ``root`` beyond which nodes are not expanded.
        modules_only: When true, only ``nn.Module`` children are followed.

    Returns:
        ``(node count, edge count)`` of the written graph.
    """
    names_by_id: dict[int, str] = {}
    labels: dict[str, str] = {}
    colours: dict[str, str] = {}
    edge_lines: list[str] = []

    def register(target: Any) -> str:
        """Give ``target`` the next node name and record its label and colour."""
        # Names are handed out sequentially, so the current count is the next index.
        name = f"n{len(names_by_id)}"
        names_by_id[id(target)] = name
        labels[name] = short_label(target)
        colour = node_fillcolor(target)
        if colour:
            colours[name] = colour
        return name

    frontier: deque[tuple[Any, int]] = deque([(root, 0)])
    while frontier and len(names_by_id) < max_nodes:
        current, level = frontier.popleft()
        # Only the root can arrive unregistered; children are registered
        # before they are queued.
        if id(current) not in names_by_id:
            register(current)
        source = names_by_id[id(current)]
        if level >= max_depth:
            continue

        for edge_label, child in iter_edges(current, modules_only=modules_only):
            if modules_only and not isinstance(child, nn.Module):
                continue
            if id(child) not in names_by_id:
                if len(names_by_id) >= max_nodes:
                    # Node budget exhausted: stop expanding this object.
                    break
                register(child)
                frontier.append((child, level + 1))
            target = names_by_id[id(child)]
            escaped = edge_label.replace('"', '\\"')
            edge_lines.append(f' {source} -> {target} [label="{escaped}"];')

    document = [
        "digraph checkpoint_object_graph {\n",
        " rankdir=LR;\n",
        ' node [shape=box, fontsize=10, fontname="Helvetica"];\n',
    ]
    for name, text in labels.items():
        attributes = 'label="{}"'.format(text.replace('"', '\\"'))
        colour = colours.get(name)
        if colour:
            attributes += f', style="filled", fillcolor="{colour}"'
        document.append(f" {name} [{attributes}];\n")
    document.extend(line + "\n" for line in edge_lines)
    document.append("}\n")

    with output_dot.open("w", encoding="utf-8") as handle:
        handle.writelines(document)
    return len(labels), len(edge_lines)
def parse_args() -> argparse.Namespace:
    """Parse the command-line arguments of the checkpoint visualizer.

    Returns:
        Namespace with ``checkpoint``, ``map_location``, ``weights_only``,
        ``out_dir``, ``max_depth``, ``max_nodes`` and ``modules_only``.
    """
    cli = argparse.ArgumentParser(description="Visualize a loaded checkpoint object.")
    # (name or flag, add_argument keyword options), in help-output order.
    specs = (
        ("checkpoint", {"nargs": "?", "default": "living-shiner.ckpt"}),
        ("--map-location", {"default": "cpu"}),
        ("--weights-only", {"action": "store_true"}),
        ("--out-dir", {"default": "checkpoint_viz"}),
        ("--max-depth", {"type": int, "default": 4}),
        ("--max-nodes", {"type": int, "default": 300}),
        (
            "--modules-only",
            {
                "action": "store_true",
                "help": "Only include torch.nn.Module objects in the object graph diagram.",
            },
        ),
    )
    for flag, options in specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def main() -> int:
    """Run the checkpoint visualizer from the command line.

    Loads the checkpoint, prints the package versions recorded in its
    metadata, writes ``module_tree.txt`` (only when the root object is an
    ``nn.Module``) and ``object_graph.dot`` into the output directory, and
    renders ``object_graph.svg`` when Graphviz ``dot`` is available.

    The output directory is created only once it is known that something will
    be written, so a mistyped checkpoint path or a failed load leaves no empty
    directory behind.

    Returns:
        Process exit status: 0 on success (including when SVG rendering is
        skipped or fails), 1 when the checkpoint file does not exist, 2 when
        loading fails, 3 when ``--modules-only`` is requested but the loaded
        root is not an ``nn.Module``.
    """
    args = parse_args()
    ckpt = Path(args.checkpoint)
    if not ckpt.exists():
        print(f"Checkpoint not found: {ckpt}")
        return 1

    try:
        obj = load_checkpoint(ckpt, map_location=args.map_location, weights_only=args.weights_only)
    except Exception as exc:
        # Top-level boundary: report any load failure as an exit status.
        print(f"Failed to load checkpoint: {exc}")
        return 2

    print_package_versions(read_module_versions_from_checkpoint(ckpt))
    print(f"Loaded: {ckpt}")
    print(f"Root type: {obj.__class__.__module__}.{obj.__class__.__name__}")

    root_is_module = isinstance(obj, nn.Module)
    if not root_is_module:
        print("Root is not an nn.Module; skipping module tree output.")
        if args.modules_only:
            print("Root is not an nn.Module; cannot build a modules-only diagram.")
            return 3

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    if root_is_module:
        tree_path = out_dir / "module_tree.txt"
        tree_text = "\n".join(module_tree_lines(obj, max_depth=args.max_depth)) + "\n"
        tree_path.write_text(tree_text, encoding="utf-8")
        print(f"Wrote module tree: {tree_path}")

    dot_path = out_dir / "object_graph.dot"
    node_count, edge_count = write_object_graph_dot(
        obj,
        output_dot=dot_path,
        max_nodes=args.max_nodes,
        max_depth=args.max_depth,
        modules_only=args.modules_only,
    )
    print(f"Wrote object graph DOT: {dot_path} ({node_count} nodes, {edge_count} edges)")

    if not shutil.which("dot"):
        print("Graphviz 'dot' not found. Install graphviz to render DOT to SVG/PNG.")
        return 0

    svg_path = out_dir / "object_graph.svg"
    result = subprocess.run(["dot", "-Tsvg", str(dot_path), "-o", str(svg_path)], check=False)
    # Check the exit code as well as the file: an SVG left over from an
    # earlier run must not be reported as freshly written when dot failed.
    if result.returncode == 0 and svg_path.exists():
        print(f"Wrote SVG: {svg_path}")
    else:
        print(f"Graphviz 'dot' failed (exit code {result.returncode}); SVG was not rendered.")
    return 0
if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    sys.exit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment