Created
February 23, 2026 19:35
-
-
Save leifdenby/613d81343c3a6243f65f01bbc0254c43 to your computer and use it in GitHub Desktop.
visualize anemoi checkpoints
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env -S uv run --script | |
| # Usage: | |
| # uv run ./script | |
| # uv run ./script living-shiner.ckpt --modules-only | |
| # uv run ./script --max-depth 6 --max-nodes 800 --out-dir checkpoint_viz_deep | |
| # | |
| # Outputs: | |
| # - <out-dir>/module_tree.txt | |
| # - <out-dir>/object_graph.dot | |
| # - <out-dir>/object_graph.svg (if graphviz `dot` is installed) | |
| # | |
| # /// script | |
| # requires-python = ">=3.11,<3.14" | |
| # dependencies = [ | |
| # "torch", | |
| # "torch-geometric==2.4.0", | |
| # "anemoi-models @ git+https://github.com/ecmwf/anemoi-core.git@models-0.5.0#subdirectory=models", | |
| # ] | |
| # /// | |
| import argparse | |
| import json | |
| import shutil | |
| import subprocess | |
| import sys | |
| import zipfile | |
| from collections import deque | |
| from pathlib import Path | |
| from typing import Any | |
| import torch | |
| from torch import nn | |
def install_torch_geometric_compat_aliases() -> None:
    """Alias the legacy PyG inspector module path to the newer top-level module.

    If ``torch_geometric.nn.conv.utils.inspector`` is not already registered in
    ``sys.modules`` and ``torch_geometric.inspector`` can be imported, the latter
    is registered under the legacy name. Objects pickled against the old path can
    then presumably be resolved when a checkpoint is unpickled.

    This is best effort: any failure to import the newer module simply means no
    alias is installed.
    """
    legacy_name = "torch_geometric.nn.conv.utils.inspector"
    if legacy_name not in sys.modules:
        try:
            from torch_geometric import inspector as modern_inspector
        except Exception:
            # Missing or incompatible torch_geometric: nothing to alias.
            modern_inspector = None
        if modern_inspector is not None:
            sys.modules[legacy_name] = modern_inspector
def load_checkpoint(path: Path, map_location: str, weights_only: bool) -> Any:
    """Deserialize a checkpoint file with ``torch.load``.

    The PyG legacy-inspector alias is installed first, so that module path is
    resolvable during unpickling.

    Args:
        path: Checkpoint file to read.
        map_location: Forwarded to ``torch.load`` (e.g. ``"cpu"``).
        weights_only: Forwarded to ``torch.load``. When False, full pickle
            deserialization is performed. Unpickling can execute arbitrary
            code, so only load checkpoints from trusted sources.

    Returns:
        Whatever object ``torch.load`` produces.
    """
    install_torch_geometric_compat_aliases()
    load_kwargs = {"map_location": map_location, "weights_only": weights_only}
    return torch.load(path, **load_kwargs)
def read_module_versions_from_checkpoint(path: Path) -> dict[str, str]:
    """Return the ``module_versions`` mapping stored in anemoi checkpoint metadata.

    The checkpoint is opened as a zip archive. The first member whose name ends
    with ``anemoi-metadata/ai-models.json`` is parsed as JSON, and the value at
    ``["provenance_training"]["module_versions"]`` is returned.

    The result is purely informational, so every failure yields an empty dict
    instead of raising. That covers: not a zip file, no metadata member, a
    corrupt archive, malformed JSON, or an unexpected JSON structure (e.g.
    ``null`` or non-object values).
    """
    if not zipfile.is_zipfile(path):
        return {}
    try:
        with zipfile.ZipFile(path) as zf:
            candidates = [n for n in zf.namelist() if n.endswith("anemoi-metadata/ai-models.json")]
            if not candidates:
                return {}
            raw = zf.read(candidates[0])
        data = json.loads(raw)
    except (zipfile.BadZipFile, OSError, ValueError):
        # ValueError also covers json.JSONDecodeError and UnicodeDecodeError.
        return {}
    # Walk the nested structure defensively: any level may be missing or not a dict.
    provenance = data.get("provenance_training") if isinstance(data, dict) else None
    versions = provenance.get("module_versions") if isinstance(provenance, dict) else None
    if not isinstance(versions, dict):
        return {}
    return versions
def print_package_versions(module_versions: dict[str, str]) -> None:
    """Print the anemoi/torch/lightning entries of a package-version mapping.

    The header reports how many entries the mapping has in total. Only packages
    whose names start with a tracked prefix are listed, sorted by name. If the
    mapping is empty, a single "not found" line is printed instead.
    """
    if not module_versions:
        print("Package versions: not found in checkpoint metadata")
        return
    print(f"Package versions from checkpoint metadata ({len(module_versions)} total):")
    tracked = ("anemoi.", "torch", "pytorch_lightning", "lightning_")
    relevant = [pkg for pkg in sorted(module_versions) if pkg.startswith(tracked)]
    for pkg in relevant:
        print(f" {pkg}=={module_versions[pkg]}")
    if len(relevant) == 0:
        print(" (No anemoi/torch/lightning packages found in module_versions)")
def module_tree_lines(module: nn.Module, max_depth: int) -> list[str]:
    """Render an ``nn.Module`` hierarchy as indented text lines.

    Each line reads ``<indent><name>: <class module>.<class name>``. The indent
    is one space per nesting level and the top-level module is named ``"root"``.
    Children appear in ``named_children()`` order (pre-order traversal).
    Modules at ``max_depth`` are listed but their children are not.
    """
    out: list[str] = []
    # Explicit stack; children are pushed in reverse so they pop in original order.
    pending: list[tuple[nn.Module, str, int]] = [(module, "root", 0)]
    while pending:
        current, label, level = pending.pop()
        klass = current.__class__
        out.append(" " * level + f"{label}: {klass.__module__}.{klass.__name__}")
        if level < max_depth:
            kids = list(current.named_children())
            for kid_name, kid in reversed(kids):
                pending.append((kid, kid_name, level + 1))
    return out
def short_label(obj: Any) -> str:
    """Build a DOT node label: the fully-qualified class name plus a detail line.

    The detail line follows a literal ``\\n`` (DOT's line-break escape). It is
    ``(nn.Module)`` for modules, ``shape=...`` for tensors, and ``len=...`` for
    dicts, lists, tuples and sets. Any other object gets only the class name.
    """
    klass = obj.__class__
    qualified = f"{klass.__module__}.{klass.__name__}"
    detail = None
    if isinstance(obj, nn.Module):
        detail = "(nn.Module)"
    elif isinstance(obj, torch.Tensor):
        detail = f"shape={tuple(obj.shape)}"
    elif isinstance(obj, (dict, list, tuple, set)):
        detail = f"len={len(obj)}"
    return qualified if detail is None else f"{qualified}\\n{detail}"
| def node_fillcolor(obj: Any) -> str | None: | |
| module_name = obj.__class__.__module__ | |
| if module_name.startswith("anemoi.models"): | |
| return "#b7e4c7" # light green | |
| if module_name.startswith("torch.nn"): | |
| return "#dbeafe" # light blue | |
| return None | |
def iter_edges(
    obj: Any,
    modules_only: bool,
    max_children: int = 20,
) -> list[tuple[str, Any]]:
    """List the outgoing ``(edge_label, child)`` pairs used to expand ``obj``.

    Args:
        obj: Object to expand.
        modules_only: When True, only ``nn.Module`` children (from
            ``named_children``) are returned, labelled ``module.<name>``.
            Non-module objects then yield no edges.
        max_children: Maximum number of entries taken from a dict, list,
            tuple or set. The default of 20 matches the previous hard-coded
            cap, and negative values are treated as 0. Module children and
            object attributes are not capped.

    Returns:
        Edges in discovery order:
        - Modules (when ``modules_only`` is False): named children, then
          non-scalar public attributes and ``_graph*`` attributes.
        - Dicts: ``key[<repr, truncated to 24 chars>]`` edges.
        - Lists and tuples: ``[i]`` edges.
        - Sets: ``set[i]`` edges.
        - Other objects with a ``__dict__``: their non-scalar public attributes.
    """
    from itertools import islice

    scalar_types = (str, int, float, bool, type(None))
    limit = max(0, max_children)

    if isinstance(obj, nn.Module):
        edges = [(f"module.{name}", child) for name, child in obj.named_children()]
        if modules_only:
            return edges
        for key, value in vars(obj).items():
            # Children are already covered above; parameters/buffers are not drawn.
            if key in {"_modules", "_parameters", "_buffers"}:
                continue
            # Private attributes are skipped, except ``_graph*`` ones, which
            # presumably hold anemoi graph data worth showing.
            if key.startswith("_") and not key.startswith("_graph"):
                continue
            if isinstance(value, scalar_types):
                continue
            edges.append((f"attr.{key}", value))
        return edges

    if modules_only:
        return []

    if isinstance(obj, dict):
        return [(f"key[{repr(k)[:24]}]", v) for k, v in islice(obj.items(), limit)]
    if isinstance(obj, (list, tuple)):
        return [(f"[{i}]", item) for i, item in enumerate(islice(obj, limit))]
    if isinstance(obj, set):
        # islice takes the prefix without copying the whole set into a list.
        return [(f"set[{i}]", item) for i, item in enumerate(islice(obj, limit))]
    if hasattr(obj, "__dict__"):
        return [
            (f"attr.{key}", value)
            for key, value in vars(obj).items()
            if not key.startswith("_") and not isinstance(value, scalar_types)
        ]
    return []
def write_object_graph_dot(
    root: Any,
    output_dot: Path,
    max_nodes: int,
    max_depth: int,
    modules_only: bool,
) -> tuple[int, int]:
    """Breadth-first walk the object graph from ``root`` and write it as DOT.

    Nodes are keyed by ``id()``, so an object reachable along several paths
    appears once. At most ``max_nodes`` nodes are created, in BFS discovery
    order. Objects at ``max_depth`` are drawn but not expanded.

    Once the node budget is spent, the walk still visits every queued node. It
    creates no new nodes, but edges between nodes already in the diagram are
    not dropped.

    Args:
        root: Object to start from (e.g. the loaded checkpoint).
        output_dot: Destination ``.dot`` file (overwritten, UTF-8).
        max_nodes: Upper bound on the number of nodes; ``<= 0`` gives an empty graph.
        max_depth: BFS depth beyond which objects are not expanded.
        modules_only: Restrict the walk to ``nn.Module`` objects.

    Returns:
        ``(node_count, edge_count)`` as written to the file.
    """
    node_ids: dict[int, str] = {}
    node_labels: dict[str, str] = {}
    node_colors: dict[str, str] = {}
    edge_lines: list[str] = []

    def add_node(obj: Any) -> str:
        """Register ``obj`` as a new node and return its DOT identifier."""
        node_name = f"n{len(node_ids)}"
        node_ids[id(obj)] = node_name
        node_labels[node_name] = short_label(obj)
        fill = node_fillcolor(obj)
        if fill:
            node_colors[node_name] = fill
        return node_name

    q: deque[tuple[Any, int]] = deque()
    if max_nodes > 0:
        add_node(root)
        q.append((root, 0))

    while q:
        obj, depth = q.popleft()
        if depth >= max_depth:
            continue
        src = node_ids[id(obj)]
        for label, child in iter_edges(obj, modules_only=modules_only):
            if modules_only and not isinstance(child, nn.Module):
                continue
            dst = node_ids.get(id(child))
            if dst is None:
                if len(node_ids) >= max_nodes:
                    # Budget spent: skip new nodes, but keep scanning so edges
                    # to nodes already in the diagram are still emitted.
                    continue
                dst = add_node(child)
                q.append((child, depth + 1))
            safe_label = label.replace('"', '\\"')
            edge_lines.append(f' {src} -> {dst} [label="{safe_label}"];')

    with output_dot.open("w", encoding="utf-8") as f:
        f.write("digraph checkpoint_object_graph {\n")
        f.write(" rankdir=LR;\n")
        f.write(' node [shape=box, fontsize=10, fontname="Helvetica"];\n')
        for node_name, label in node_labels.items():
            safe = label.replace('"', '\\"')
            fill = node_colors.get(node_name)
            if fill:
                f.write(f' {node_name} [label="{safe}", style="filled", fillcolor="{fill}"];\n')
            else:
                f.write(f' {node_name} [label="{safe}"];\n')
        for edge in edge_lines:
            f.write(edge + "\n")
        f.write("}\n")
    return len(node_labels), len(edge_lines)
| def parse_args() -> argparse.Namespace: | |
| parser = argparse.ArgumentParser(description="Visualize a loaded checkpoint object.") | |
| parser.add_argument("checkpoint", nargs="?", default="living-shiner.ckpt") | |
| parser.add_argument("--map-location", default="cpu") | |
| parser.add_argument("--weights-only", action="store_true") | |
| parser.add_argument("--out-dir", default="checkpoint_viz") | |
| parser.add_argument("--max-depth", type=int, default=4) | |
| parser.add_argument("--max-nodes", type=int, default=300) | |
| parser.add_argument( | |
| "--modules-only", | |
| action="store_true", | |
| help="Only include torch.nn.Module objects in the object graph diagram.", | |
| ) | |
| return parser.parse_args() | |
def main() -> int:
    """Command-line entry point: load a checkpoint and write visualizations.

    Writes ``module_tree.txt`` (when the root is an ``nn.Module``) and
    ``object_graph.dot`` into ``--out-dir``. If Graphviz ``dot`` is on PATH,
    it also renders ``object_graph.svg``.

    Returns:
        Process exit code:
        - 0: success
        - 1: checkpoint file not found
        - 2: checkpoint failed to load
        - 3: ``--modules-only`` requested but the root is not an ``nn.Module``
    """
    args = parse_args()
    ckpt = Path(args.checkpoint)
    if not ckpt.exists():
        print(f"Checkpoint not found: {ckpt}", file=sys.stderr)
        return 1
    try:
        obj = load_checkpoint(ckpt, map_location=args.map_location, weights_only=args.weights_only)
    except Exception as exc:
        # Top-level boundary: report any deserialization failure and exit non-zero.
        print(f"Failed to load checkpoint: {exc}", file=sys.stderr)
        return 2

    # Create the output directory only once there is something to write into it.
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    print_package_versions(read_module_versions_from_checkpoint(ckpt))
    print(f"Loaded: {ckpt}")
    print(f"Root type: {obj.__class__.__module__}.{obj.__class__.__name__}")

    if isinstance(obj, nn.Module):
        tree_path = out_dir / "module_tree.txt"
        tree_path.write_text("\n".join(module_tree_lines(obj, max_depth=args.max_depth)) + "\n", encoding="utf-8")
        print(f"Wrote module tree: {tree_path}")
    else:
        print("Root is not an nn.Module; skipping module tree output.")

    dot_path = out_dir / "object_graph.dot"
    if args.modules_only and not isinstance(obj, nn.Module):
        print("Root is not an nn.Module; cannot build a modules-only diagram.", file=sys.stderr)
        return 3
    node_count, edge_count = write_object_graph_dot(
        obj,
        output_dot=dot_path,
        max_nodes=args.max_nodes,
        max_depth=args.max_depth,
        modules_only=args.modules_only,
    )
    print(f"Wrote object graph DOT: {dot_path} ({node_count} nodes, {edge_count} edges)")

    if shutil.which("dot"):
        svg_path = out_dir / "object_graph.svg"
        result = subprocess.run(
            ["dot", "-Tsvg", str(dot_path), "-o", str(svg_path)],
            check=False,
            capture_output=True,
            text=True,
            errors="replace",
        )
        # Check the exit status rather than svg_path.exists(): an SVG left over
        # from a previous run would otherwise mask a failed render.
        if result.returncode == 0:
            print(f"Wrote SVG: {svg_path}")
        else:
            print(f"Graphviz 'dot' failed (exit code {result.returncode}); SVG not updated.", file=sys.stderr)
            if result.stderr.strip():
                print(result.stderr.strip(), file=sys.stderr)
    else:
        print("Graphviz 'dot' not found. Install graphviz to render DOT to SVG/PNG.")
    return 0
if __name__ == "__main__":
    # Propagate main()'s integer result as the process exit status.
    exit_code = main()
    raise SystemExit(exit_code)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment