Generic PyTorch Dataloading Benchmark
@johnmeade · Last active June 25, 2025
"""
Generic PyTorch Dataloading Benchmark.
MIT Licence / John Meade.
# installation
pip install torch numpy accelerate
# example
ulimit -n 20000
accelerate config --config_file "acc-cfg-8gpu.yaml"
accelerate launch --config_file "acc-cfg-8gpu.yaml" bench-dataload-acc.py /mnt/foo/dataset1 -a wavs speech_tokens
accelerate launch --config_file "acc-cfg-8gpu.yaml" bench-dataload-acc.py /mnt/bar/dataset2 -a wavs speech_tokens
"""
import argparse
import gc
import librosa
import numpy as np
import torch
import torch.nn.functional as F  # needed by pad_or_trim below
from accelerate import Accelerator
from collections import defaultdict
from itertools import count
from pathlib import Path
from time import perf_counter
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


def loadtensor(fp):
    fp = Path(fp)
    if fp.suffix == ".npy":
        ary = np.load(fp)
        tensor = torch.from_numpy(ary)
    elif fp.suffix == ".pt":
        tensor = torch.load(fp)
    elif fp.suffix in [".wav", ".flac", ".mp3"]:
        ary, _ = librosa.load(str(fp))
        tensor = torch.from_numpy(ary)
    else:
        print(f"no loader for filetype: {fp}")
        tensor = None
    return tensor
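
# Example calls (hypothetical paths, matching the CLI examples in the header;
# audio decodes to mono at librosa's default 22050 Hz):
#   loadtensor("/mnt/foo/dataset1/wavs/utt001.wav")           # 1-D float32 waveform
#   loadtensor("/mnt/foo/dataset1/speech_tokens/utt001.npy")  # tensor of the .npy contents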


def pad_or_trim(x, l, dim=0, left=False):
    """
    Right pad `dim` of `x` up to length `l` (if lower), and trim `dim` of `x`
    down to length `l` (if larger). Left pad/trim if `left=True`.
    Eg:
    >>> x = torch.tensor([[1, 2],
    ...                   [3, 4]])
    >>> pad_or_trim(x, l=3, dim=0)
    tensor([[1, 2],
            [3, 4],
            [0, 0]])
    >>> pad_or_trim(x, l=3, dim=1)
    tensor([[1, 2, 0],
            [3, 4, 0]])
    >>> pad_or_trim(x, l=1, dim=0)
    tensor([[1, 2]])
    >>> pad_or_trim(x, l=1, dim=1)
    tensor([[1],
            [3]])
    """
    xl = x.size(dim)
    if xl < l:
        # F.pad takes pad pairs for the trailing dims first, so pad `dim` last
        x = F.pad(
            x,
            [0, 0] * (x.ndim - dim - 1) +
            ([l - x.size(dim), 0] if left else [0, l - x.size(dim)]),
        )
    if xl > l:
        x = x[tuple(
            [slice(None) for _ in range(dim)] +
            [slice(-l, None) if left else slice(0, l)] +
            [slice(None) for _ in range(x.ndim - dim - 1)]
        )]
    return x
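
# The doctest above only covers the default right pad/trim; with left=True the
# padding/trimming happens at the front instead:
#   pad_or_trim(torch.tensor([1, 2]), l=4, left=True)     # -> tensor([0, 0, 1, 2])
#   pad_or_trim(torch.tensor([1, 2, 3]), l=2, left=True)  # -> tensor([2, 3])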


_pad_dims = dict()  # cache the pad dimension for each data field


def autocollate(batch):
    """
    Automatically decide how to collate batch data. Arrays and Tensors are scanned
    for the first dimension that does not match across samples, and items are padded
    along it automatically. Other data is left as a Python list.
    Args:
        batch: list of dicts with sample data
    """
    global _pad_dims
    collated = dict()
    fields = list(batch[0])
    for field in fields:
        # get list of data for this field
        data = [sample[field] for sample in batch]
        # only collate arrays/tensors
        if isinstance(data[0], np.ndarray):
            data = [torch.from_numpy(d) for d in data]
        if not torch.is_tensor(data[0]):
            collated[field] = data
            continue
        # guess dimension to pad: first dimension with unmatched sizes
        if field not in _pad_dims:
            _pad_dims[field] = 0
            for i in range(len(data[0].shape)):
                if any(d.size(i) != data[0].size(i) for d in data[1:]):
                    _pad_dims[field] = i
                    break
        pad_dim = _pad_dims[field]
        # pad tensors to the longest sample and stack
        L = max(d.size(pad_dim) for d in data)
        collated[field] = torch.stack([pad_or_trim(x, L, dim=pad_dim) for x in data])
    return collated
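
# Example: two samples whose field differs along dim 0 get right-padded and stacked
# ("wavs" here is just an illustrative field name from the CLI examples above):
#   batch = [{"wavs": torch.ones(3)}, {"wavs": torch.ones(5)}]
#   autocollate(batch)["wavs"].shape  # -> torch.Size([2, 5])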


class GenericDataset(Dataset):
    def __init__(self, data_dir, annots, main_process=False):
        self.data_dir = data_dir
        if main_process:
            print("[ enumerating files ]")
        t0 = perf_counter()
        datapaths = defaultdict(dict)
        counter = count()
        order = defaultdict(lambda: next(counter))
        for ann in annots:
            for ad in data_dir.glob(f"{ann}*"):
                if main_process:
                    print(f"> {ad}", flush=True)
                for fp in ad.iterdir():
                    bn = fp.stem
                    datapaths[bn][ann] = str(fp)
                    order[bn]  # defaultdict access assigns the next index to new basenames
        gc.collect()
        order = {v: k for k, v in order.items()}  # index -> basename
        t1 = perf_counter()
        if main_process:
            print(f"> took {int(1000 * (t1 - t0)):,} ms")
        self.datapaths, self.order, self.annots = datapaths, order, annots

    def __len__(self):
        return len(self.datapaths)

    def __getitem__(self, idx):
        bn = self.order[idx]
        datapaths = self.datapaths[bn]
        data = dict()
        for ann in self.annots:
            if ann in datapaths:
                data[ann] = loadtensor(datapaths[ann])
        return data
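
# Assumed on-disk layout (hypothetical names): one subdirectory per annotation under
# data_dir, with files paired across annotations by shared basename:
#   /mnt/foo/dataset1/wavs/utt001.wav
#   /mnt/foo/dataset1/speech_tokens/utt001.npy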


def benchmark_dataloader(data_dir, annotations, batch_size=8, accum_steps=8, num_epochs=1, nproc=16):
    accelerator = Accelerator(gradient_accumulation_steps=accum_steps)
    device = accelerator.device
    n_gpus = accelerator.state.num_processes
    main_process = accelerator.is_main_process
    if main_process:
        print(f"> Batch size per GPU: {batch_size}")
        print(f"> Num GPU: {n_gpus}")
        print(f"> Accumulation steps: {accum_steps}")
        print(f"> Total batch size: {batch_size * accum_steps * n_gpus}", flush=True)
    dataset = GenericDataset(data_dir, annotations, main_process)
    if main_process:
        print(f"dataset size: {len(dataset)}")
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=nproc,
        collate_fn=autocollate,
    )
    dataloader = accelerator.prepare(dataloader)
    if main_process:
        print("Starting main loop")
    for epoch in range(num_epochs):
        start_time = perf_counter()
        for batch in tqdm(dataloader, total=len(dataloader), disable=not main_process):
            # touch each tensor on the GPU so the device transfer actually happens
            for v in batch.values():
                if torch.is_tensor(v):
                    v.to(device) + 1
        end_time = perf_counter()
        if main_process:
            print(f"Epoch {epoch + 1}: {end_time - start_time:.4f} seconds", flush=True)


# torch.multiprocessing.set_sharing_strategy('file_system')
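# (Uncommenting the line above switches worker tensor sharing to the 'file_system'
#  strategy, an alternative to raising `ulimit -n` when many dataloader workers
#  exhaust file descriptors under the default 'file_descriptor' strategy.)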


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('data_dir', type=Path)
    parser.add_argument('-a', '--annotations', nargs="+")
    args = parser.parse_args()
    benchmark_dataloader(args.data_dir, args.annotations, batch_size=8, accum_steps=8, nproc=16)
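
# For a quick single-process sanity check, the script should also run without the
# launcher (Accelerator falls back to a single process):
#   python bench-dataload-acc.py /mnt/foo/dataset1 -a wavs speech_tokens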