Last active
June 25, 2025 19:07
-
-
Save johnmeade/5c18fc5b1a5547e093272ec3aa5f6016 to your computer and use it in GitHub Desktop.
Generic PyTorch Dataloading Benchmark
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Generic PyTorch Dataloading Benchmark. | |
| MIT Licence / John Meade. | |
| # installation | |
| pip install torch numpy accelerate librosa tqdm | |
| # example | |
| ulimit -n 20000 | |
| accelerate config --config_file "acc-cfg-8gpu.yaml" | |
| accelerate launch --config_file "acc-cfg-8gpu.yaml" bench-dataload-acc.py /mnt/foo/dataset1 -a wavs speech_tokens | |
| accelerate launch --config_file "acc-cfg-8gpu.yaml" bench-dataload-acc.py /mnt/bar/dataset2 -a wavs speech_tokens | |
| """ | |
| import argparse | |
| import librosa | |
| import numpy as np | |
| import torch | |
| import gc | |
| from accelerate import Accelerator | |
| from collections import defaultdict | |
| from itertools import count | |
| from pathlib import Path | |
| from time import perf_counter | |
| from torch.utils.data import Dataset, DataLoader | |
| from tqdm import tqdm | |
def loadtensor(fp):
    """
    Load one file from disk as a torch tensor, dispatching on its extension.

    Supported extensions (matched case-insensitively):
      - `.npy`: `np.load`, wrapped with `torch.from_numpy`
      - `.pt`: `torch.load`, mapped onto the CPU
      - `.wav` / `.flac` / `.mp3`: `librosa.load`, wrapped with `torch.from_numpy`

    Args:
        fp: path to the file (str or Path).

    Returns:
        The loaded tensor, or None (after printing a message) for an unsupported extension.
    """
    fp = Path(fp)
    suffix = fp.suffix.lower()
    if suffix == ".npy":
        return torch.from_numpy(np.load(fp))
    if suffix == ".pt":
        # This runs inside DataLoader worker processes; force CPU so a tensor that was saved
        # from a GPU is not deserialised onto CUDA inside a worker.
        return torch.load(fp, map_location="cpu")
    if suffix in (".wav", ".flac", ".mp3"):
        # NOTE(review): librosa.load is called with its default arguments, which (per the
        # librosa docs) resample to 22050 Hz mono; pass sr=None to keep the native rate.
        ary, _ = librosa.load(str(fp))
        return torch.from_numpy(ary)
    print(f"no loader for filetype: {fp}")
    return None
def pad_or_trim(x, l, dim=0, left=False):
    """
    Right pad `dim` of `x` with zeros up to length `l` (if shorter), and trim `dim` of `x` to
    length `l` (if longer). Pad/trim at the start of `dim` instead if `left=True`.

    Args:
        x: input tensor.
        l: target length along `dim`.
        dim: dimension to pad/trim; negative values count from the end.
        left: if True, pad/trim at the start of `dim` rather than the end.

    Returns:
        A tensor whose size along `dim` is `l`, with all other dims unchanged. `x` itself is
        returned when its size along `dim` is already `l`.

    Eg:
        >>> x = torch.tensor([[1,2],
                              [3,4]])
        >>> pad_or_trim(x, l=3, dim=0)
        tensor([[1, 2],
                [3, 4],
                [0, 0]])
        >>> pad_or_trim(x, l=3, dim=1)
        tensor([[1, 2, 0],
                [3, 4, 0]])
        >>> pad_or_trim(x, l=1, dim=0)
        tensor([[1, 2]])
        >>> pad_or_trim(x, l=1, dim=1)
        tensor([[1],
                [3]])
    """
    if dim < 0:
        dim += x.ndim
    xl = x.size(dim)
    if xl < l:
        amount = l - xl
        # F.pad takes (before, after) pairs starting from the LAST dimension: every dim after
        # `dim` gets (0, 0), and the final pair applies to `dim` itself.
        pad = [0, 0] * (x.ndim - dim - 1) + ([amount, 0] if left else [0, amount])
        x = torch.nn.functional.pad(x, pad)
    elif xl > l:
        index = [slice(None)] * x.ndim
        # slice(xl - l, xl) instead of slice(-l, None): the latter keeps everything when l == 0.
        index[dim] = slice(xl - l, xl) if left else slice(0, l)
        x = x[tuple(index)]
    return x
# field name -> dimension to pad; only cached once a batch actually shows a size mismatch
_pad_dims = dict()


def autocollate(batch):
    """
    Automatically decide how to collate batch data.

    NumPy arrays are converted to tensors. When every sample's value for a field is a tensor,
    the values are stacked. If their shapes differ, they are first right-padded with zeros along
    the first dimension whose size does not match across samples (remembered per field in
    `_pad_dims`). Any other field is left as a Python list. A sample missing a field
    contributes `None`, which makes that field a list.

    Args:
        batch: list of dicts with sample data.

    Returns:
        dict mapping each field seen in the batch (in first-seen order) to a stacked tensor or a
        list of per-sample values. An empty batch gives an empty dict.
    """
    collated = dict()
    # Union of keys across samples: a field absent from batch[0] must not be silently dropped.
    fields = list(dict.fromkeys(key for sample in batch for key in sample))
    for field in fields:
        values = [sample.get(field) for sample in batch]
        values = [torch.from_numpy(v) if isinstance(v, np.ndarray) else v for v in values]
        # only collate when every value is an array/tensor
        if not all(torch.is_tensor(v) for v in values):
            collated[field] = values
            continue
        first = values[0]
        if all(v.shape == first.shape for v in values):
            # Nothing to pad (also covers 0-d tensors). Don't cache a pad dim here: a guess made
            # from a uniform batch could contradict later batches.
            collated[field] = torch.stack(values)
            continue
        # guess dimension to pad: first dimension with unmatched sizes
        if field not in _pad_dims:
            _pad_dims[field] = next(
                (
                    i
                    for i in range(first.ndim)
                    if any(v.size(i) != first.size(i) for v in values[1:])
                ),
                0,
            )
        pad_dim = _pad_dims[field]
        target = max(v.size(pad_dim) for v in values)
        collated[field] = torch.stack([pad_or_trim(v, target, dim=pad_dim) for v in values])
    return collated
class GenericDataset(Dataset):
    """
    Map-style dataset over per-annotation subdirectories of `data_dir`.

    For each annotation name `ann`, every directory matching `data_dir/{ann}*` is scanned, and
    its files are grouped by stem (`Path.stem`). Each distinct stem is one sample, and
    `__getitem__` returns a dict mapping each annotation present for that stem to the result of
    `loadtensor`.

    Directories and files are visited in sorted order. As a result, each process (e.g. each
    accelerate rank) assigns the same index to the same sample, independent of filesystem
    listing order.

    Args:
        data_dir: dataset root (str or Path).
        annots: iterable of annotation directory-name prefixes, e.g. ["wavs", "speech_tokens"].
        main_process: if True, print progress and timing information.
    """

    def __init__(self, data_dir, annots, main_process=False):
        data_dir = Path(data_dir)
        annots = list(annots)  # iterated again in __getitem__, so must not be a one-shot iterator
        self.data_dir = data_dir
        if main_process:
            print("[ enumerating files ]")
        t0 = perf_counter()
        datapaths = defaultdict(dict)  # stem -> {annotation: file path}
        order = dict()  # sample index -> stem, in first-seen order
        for ann in annots:
            for ad in sorted(data_dir.glob(f"{ann}*")):
                # the prefix glob can also match plain files, which have no entries to list
                if not ad.is_dir():
                    continue
                if main_process:
                    print(f"> {ad}", flush=True)
                for fp in sorted(ad.iterdir()):
                    bn = fp.stem
                    if bn not in datapaths:
                        order[len(order)] = bn
                    datapaths[bn][ann] = str(fp)
        # NOTE(review): presumably frees temporaries from enumerating large directories
        gc.collect()
        t1 = perf_counter()
        if main_process:
            print(f"> took {int(1000*(t1 - t0)):,} ms")
        self.datapaths, self.order, self.annots = dict(datapaths), order, annots

    def __len__(self):
        return len(self.datapaths)

    def __getitem__(self, idx):
        """Return {annotation: tensor} for every annotation that exists for sample `idx`."""
        paths = self.datapaths[self.order[idx]]
        return {ann: loadtensor(paths[ann]) for ann in self.annots if ann in paths}
def benchmark_dataloader(data_dir, annotations, batch_size=8, accum_steps=8, num_epochs=1, nproc=16):
    """
    Time full passes over a DataLoader built on `GenericDataset`, distributed with accelerate.

    Each process builds the dataset and wraps a shuffled DataLoader with `accelerator.prepare`.
    It then iterates the loader `num_epochs` times, touching every tensor on its own device. The
    main process prints the configuration and the per-epoch wall-clock time.

    Args:
        data_dir: dataset root passed to `GenericDataset`.
        annotations: annotation directory prefixes passed to `GenericDataset`.
        batch_size: per-process batch size.
        accum_steps: gradient-accumulation steps. Used only for the Accelerator config and the
            printed total batch size; no training step is run.
        num_epochs: number of passes over the data.
        nproc: number of DataLoader worker processes per process.
    """
    accelerator = Accelerator(gradient_accumulation_steps=accum_steps)
    device = accelerator.device
    n_gpus = accelerator.state.num_processes
    main_process = accelerator.is_main_process
    if main_process:
        print(f"> Batch size per GPU: {batch_size}")
        print(f"> Num GPU: {n_gpus}")
        print(f"> Accumulation steps: {accum_steps}")
        print(f"> Total Batch size: {batch_size * accum_steps * n_gpus}", flush=True)
    dataset = GenericDataset(data_dir, annotations, main_process)
    if main_process:
        print(f"dataset size: {len(dataset)}")
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=nproc,
        collate_fn=autocollate,
    )
    dataloader = accelerator.prepare(dataloader)
    if main_process:
        print("Starting main loop")
    for epoch in range(num_epochs):
        start_time = perf_counter()
        for batch in tqdm(dataloader, total=len(dataloader), disable=not main_process):
            for v in batch.values():
                # Touch each tensor on this process's device so transfer cost is timed.
                # autocollate leaves non-tensor fields as lists, which are skipped.
                if torch.is_tensor(v):
                    v.to(device) + 1
        if device.type == "cuda":
            # CUDA work is queued asynchronously; wait for it so the epoch time includes it.
            torch.cuda.synchronize(device)
        end_time = perf_counter()
        if main_process:
            print(f"Epoch {epoch + 1}: {end_time - start_time:.4f} seconds", flush=True)
# NOTE(review): presumably an alternative to the `ulimit -n` advice in the module docstring.
# PyTorch's default worker tensor-sharing strategy holds file descriptors; if workers run out,
# uncomment:
# torch.multiprocessing.set_sharing_strategy('file_system')
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Generic PyTorch dataloading benchmark.")
    parser.add_argument('data_dir', type=Path, help="dataset root directory")
    parser.add_argument(
        '-a', '--annotations', nargs="+", required=True,
        help="annotation subdirectory prefixes to load, e.g. wavs speech_tokens",
    )
    parser.add_argument('--batch-size', type=int, default=8, help="per-GPU batch size")
    parser.add_argument('--accum-steps', type=int, default=8, help="gradient accumulation steps")
    parser.add_argument('--num-epochs', type=int, default=1, help="passes over the dataset")
    parser.add_argument('--nproc', type=int, default=16, help="DataLoader workers per GPU")
    args = parser.parse_args()
    benchmark_dataloader(
        args.data_dir,
        args.annotations,
        batch_size=args.batch_size,
        accum_steps=args.accum_steps,
        num_epochs=args.num_epochs,
        nproc=args.nproc,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment