Checkpoint Statistics: shard the tensors of a torch.distributed.checkpoint save across ranks and append per-tensor summaries to per-rank CSV files.
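For reference, the script walks the checkpoint's .metadata pickle: its storage_data dict maps each MetadataIndex (which carries the tensor's fqn) to a storage record with relative_path, offset, and length. A minimal local-filesystem sketch of that inspection (the path is illustrative, not from the original script):

import pickle

with open("/path/to/checkpoint/weights/.metadata", "rb") as f:
    md = pickle.load(f)

# Print the first few entries: tensor name, shard file, and byte range.
for md_idx, sinfo in list(md.storage_data.items())[:5]:
    print(md_idx.fqn, sinfo.relative_path, sinfo.offset, sinfo.length)

The full script follows.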
"""
Launch like:
OMP_NUM_THREADS=1 LOGLEVEL=WARNING torchrun --rdzv-backend static --master-addr ${VC_TRAIN_0_HOSTS:-$(hostname --fqdn)} \
--node-rank $NODE_INDEX --nnodes $NNODES --nproc-per-node 8 /shared/workspace/rynli/check_nans_ckpt.py
"""
import io
import os
import pickle
from functools import partial
from typing import IO, cast

import lovely_tensors as lt
import pandas as pd
import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed.checkpoint.utils import _create_file_view

from src_patch.utils.s3path import S3Path

# Patch tensor reprs so str(tensor) yields a compact stats summary
# (shape, dtype, mean/std, NaN/inf flags) instead of raw values.
lt.monkey_patch()


def _slice_file(file, sinfo):
    # Restrict reads to this item's byte range [offset, offset + length).
    return _create_file_view(file, sinfo.offset, sinfo.length)


def resolve_tensor(stream, item_md):
    """Deserialize the tensor stored at item_md's byte range within stream."""
    file_slice = _slice_file(stream, item_md)
    try:
        tensor = cast(
            Tensor,
            torch.load(cast(IO[bytes], file_slice), map_location="cpu"),
        )
        # Some entries wrap their payloads in nested BytesIO buffers; unpack them.
        for idx, t in enumerate(tensor):
            if isinstance(t, io.BytesIO):
                t.seek(0)
                tensor[idx] = torch.load(t, map_location="cpu", weights_only=True)
    except Exception:
        # Fall back to copying the raw byte range into memory before loading.
        read_bytes = io.BytesIO(file_slice.read(item_md.length))
        read_bytes.seek(0)
        tensor = torch.load(read_bytes, map_location="cpu", weights_only=True)
    return tensor


def p_tensor(item, path, file_cache=None, return_tensor=False):
    """Load the tensor behind one (MetadataIndex, storage info) pair; print or return it."""
    md_idx, sinfo = item
    f = get_stream(path, sinfo, file_cache) if file_cache else (path / sinfo.relative_path).open("rb")
    tensor = resolve_tensor(f, sinfo)[0]
    if return_tensor:
        return tensor
    print(f"{md_idx.fqn}:\n{tensor}")


def get_stream(path, sinfo, cache):
    # Reuse an already-open handle for this shard file when one is cached.
    fpath = path / sinfo.relative_path
    if fpath not in cache:
        cache[fpath] = fpath.open("rb")
    return cache[fpath]


def sort_key(item):
    return item[0].fqn


prefix = "N/A"  # checkpoint root URI (redacted)
csv_path = "/shared/workspace/ashvinn/tensor_stats"


def main() -> None:
    steps = [
        "megatron_gpt--step=28000-consumed_samples=29320000.0",
    ]
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])
    for step in steps:
        ckpt = f"{prefix}/{step}/weights"
        new_pth = S3Path(ckpt)
        with (new_pth / ".metadata").open("rb") as f:
            new_md = pickle.loads(f.read())
        new_items = list(new_md.storage_data.items())
        new_sorted = sorted(new_items, key=sort_key)
        p_new = partial(p_tensor, path=new_pth)
        # Split the sorted items into contiguous, roughly equal shards, one per
        # rank; the last rank also picks up the remainder.
        range_to_cover = (0, len(new_sorted))
        num_total = range_to_cover[1] - range_to_cover[0]
        per_rank = num_total // world_size
        start = range_to_cover[0] + rank * per_rank
        end = start + per_rank if rank < world_size - 1 else range_to_cover[1]
        CSV_FILE = f"{csv_path}_rank_{rank}_checkpoint_{step}.csv"
        if os.path.exists(CSV_FILE):
            df = pd.read_csv(CSV_FILE)
        else:
            df = pd.DataFrame(columns=["tensor_name", "lovely_tensor", "nans", "large"])
            df.to_csv(CSV_FILE, index=False)
        # Skip tensors already recorded by a previous partial run. Matching on
        # fqn is robust to the CSV's row index not lining up with i on reload.
        done = set(df["tensor_name"])
        for i in range(start, end):
            fqn = new_sorted[i][0].fqn
            if fqn in done:
                continue
            # "nans" and "large" are left as 0 placeholders; the lovely_tensors
            # repr stored in "lovely_tensor" already flags NaN/inf values.
            df.loc[i] = [fqn, str(p_new(new_sorted[i], file_cache=None, return_tensor=True)), 0, 0]
            df.loc[[i]].to_csv(CSV_FILE, mode="a", index=False, header=False)
        dist.barrier()


if __name__ == "__main__":
    dist.init_process_group(backend="gloo")
    main()
    dist.barrier()
    dist.destroy_process_group()
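
Each rank writes its own CSV shard. A minimal follow-up sketch for merging the shards into one file, assuming they all sit in the directory used above (the glob pattern and output path are illustrative, not part of the original script):

import glob

import pandas as pd

step = "megatron_gpt--step=28000-consumed_samples=29320000.0"
pattern = f"/shared/workspace/ashvinn/tensor_stats_rank_*_checkpoint_{step}.csv"

# Concatenate the per-rank shards and sort rows by tensor name.
shards = [pd.read_csv(p) for p in sorted(glob.glob(pattern))]
merged = pd.concat(shards, ignore_index=True).sort_values("tensor_name")
merged.to_csv(f"tensor_stats_checkpoint_{step}.csv", index=False)
print(f"merged {len(shards)} shards, {len(merged)} rows")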