Created May 29, 2025 01:51
Checkpoint Statistics
| """ | |
| Launch like: | |
| OMP_NUM_THREADS=1 LOGLEVEL=WARNING torchrun --rdzv-backend static --master-addr ${VC_TRAIN_0_HOSTS:-$(hostname --fqdn)} \ | |
| --node-rank $NODE_INDEX --nnodes $NNODES --nproc-per-node 8 /shared/workspace/rynli/check_nans_ckpt.py | |
| """ | |
| import os | |
| import pickle | |
| import io | |
| import torch.distributed as dist | |
| from torch.distributed.checkpoint.utils import _create_file_view | |
| from src_patch.utils.s3path import S3Path | |
| from typing import cast, IO | |
| import torch | |
| from torch import Tensor | |
| from pathlib import Path | |
| from functools import partial | |
| import lovely_tensors as lt | |
| import pandas as pd | |
| lt.monkey_patch() | |
| def _slice_file(file, sinfo): | |
| return _create_file_view(file, sinfo.offset, sinfo.length) | |
| def resolve_tensor(stream, item_md): | |
| file_slice = _slice_file(stream, item_md) | |
| try: | |
| tensor = cast( | |
| Tensor, | |
| torch.load(cast(IO[bytes], file_slice), map_location="cpu"), | |
| ) | |
| for idx, t in enumerate(tensor): | |
| if isinstance(t, io.BytesIO): | |
| t.seek(0) | |
| tensor[idx] = torch.load(t, map_location="cpu", weights_only=True) | |
| except Exception as e: | |
| read_bytes = io.BytesIO(file_slice.read(item_md.length)) | |
| read_bytes.seek(0) | |
| tensor = torch.load(read_bytes,map_location="cpu", weights_only=True) | |
| return tensor | |
| def p_tensor(item, path, file_cache=None, return_tensor=False): | |
| md_idx, sinfo = item | |
| f = get_stream(path, sinfo, file_cache) if file_cache else (path / sinfo.relative_path).open("rb") | |
| tensor = resolve_tensor(f, sinfo)[0] | |
| if return_tensor: | |
| return tensor | |
| print(f"{md_idx.fqn}:\n{tensor}") | |
| def get_stream(path, sinfo, cache): | |
| fpath = path / sinfo.relative_path | |
| if fpath not in cache: | |
| cache[fpath] = fpath.open("rb") | |
| return cache[fpath] | |
| def sort_key(item): | |
| return item[0].fqn | |
| prefix = "N/A" | |
| csv_path = "/shared/workspace/ashvinn/tensor_stats" | |
| def main() -> None: | |
| steps = [ | |
| "megatron_gpt--step=28000-consumed_samples=29320000.0" | |
| ] | |
| world_size = int(os.environ["WORLD_SIZE"]) | |
| rank = int(os.environ["RANK"]) | |
| for step in steps: | |
| ckpt = f"{prefix}/{step}/weights" | |
| new_pth = S3Path(ckpt) | |
| with (new_pth / '.metadata').open("rb") as f: | |
| new_md = pickle.loads(f.read()) | |
| new_items = list(new_md.storage_data.items()) | |
| new_sorted = sorted(new_items, key=sort_key) | |
| p_new = partial(p_tensor, path=new_pth) | |
| range_to_cover = (0, len(new_sorted)) | |
| num_total = range_to_cover[1] - range_to_cover[0] | |
| per_rank = num_total // world_size | |
| start = range_to_cover[0] + rank * per_rank | |
| end = start + per_rank if rank < world_size - 1 else range_to_cover[1] | |
| CSV_FILE = f"{csv_path}_rank_{rank}_checkpoint_{step}.csv" | |
| if os.path.exists(CSV_FILE): | |
| df = pd.read_csv(CSV_FILE) | |
| else: | |
| df = pd.DataFrame(columns=["tensor_name", "lovely_tensor", "nans", "large"]) | |
| df.to_csv(CSV_FILE, index=False) | |
| for i in range(start, end): | |
| if i in df.index and df.loc[i, "tensor_name"] == new_sorted[i][0].fqn: | |
| continue | |
| df.loc[i] = [new_sorted[i][0].fqn, str(p_new(new_sorted[i], file_cache=None, return_tensor=True)), 0, 0 ] | |
| df.loc[[i]].to_csv(CSV_FILE, mode="a", index=False, header=False) | |
| dist.barrier() | |
| if __name__ == "__main__": | |
| dist.init_process_group(backend="gloo") | |
| main() | |
| dist.barrier() | |
| dist.destroy_process_group() |
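The rows above are written with hard-coded 0 placeholders in the "nans" and "large" columns. Below is a minimal sketch of how those columns could be filled and the per-rank CSVs combined afterwards; tensor_stats, merge_rank_csvs, and LARGE_THRESHOLD are hypothetical names and an assumed threshold, not part of the gist.

import glob

import pandas as pd
import torch

LARGE_THRESHOLD = 1e4  # assumption: magnitude above which a value counts as "large"


def tensor_stats(t, large_threshold=LARGE_THRESHOLD):
    # Count NaNs and large-magnitude entries; cast to float32 so bf16/fp16 tensors work too.
    t = t.detach().float()
    nan_count = int(torch.isnan(t).sum().item())
    large_count = int((t.abs() > large_threshold).sum().item())
    return nan_count, large_count


def merge_rank_csvs(csv_path, step):
    # Concatenate the per-rank CSVs written by main() into a single DataFrame.
    files = sorted(glob.glob(f"{csv_path}_rank_*_checkpoint_{step}.csv"))
    return pd.concat([pd.read_csv(p) for p in files], ignore_index=True).sort_values("tensor_name")

Inside the inner loop, tensor_stats(tensor) could replace the trailing 0, 0 before the row is appended, and merge_rank_csvs could be run once after all ranks finish.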