Skip to content

Instantly share code, notes, and snippets.

@leoleoasd
Created March 1, 2026 05:41
Show Gist options
  • Select an option

  • Save leoleoasd/65dbb1c115ca2cb1a5ba89bc456e5fdf to your computer and use it in GitHub Desktop.

Select an option

Save leoleoasd/65dbb1c115ca2cb1a5ba89bc456e5fdf to your computer and use it in GitHub Desktop.
NCCL Broadcast/Gather file
"""
nccl_file_ops.py
Broadcast files/directories across Ray actors using torch.distributed + NCCL.
Supports async pipelined chunked transfers to overlap disk I/O with network.
Usage:
import ray
from nccl_file_ops import broadcast_files
ray.init()
# Broadcast a directory from this node to all others
group = broadcast_files("/data/model_weights/")
# Reuse for another broadcast
ray.get(group.broadcast.remote("/data/dataset/"))
# Cleanup
ray.get(group.teardown.remote())
"""
import os
import ctypes
import json
import math
import time
import logging
from pathlib import Path
from dataclasses import dataclass, field
from typing import Optional
from ray.experimental.tqdm_ray import tqdm as ray_tqdm
import ray
import torch
import torch.distributed as dist
from shared.utils import get_random_free_port
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# libc pread / pwrite — bypass Python buffer allocation
# ---------------------------------------------------------------------------
# Load glibc and declare C prototypes for pread/pwrite so ctypes marshals the
# 64-bit file offset correctly on all platforms; use_errno=True lets callers
# read errno via ctypes.get_errno() after a failed call.
_libc = ctypes.CDLL("libc.so.6", use_errno=True)
_libc.pread.argtypes = [ctypes.c_int, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int64]
_libc.pread.restype = ctypes.c_ssize_t
_libc.pwrite.argtypes = [ctypes.c_int, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int64]
_libc.pwrite.restype = ctypes.c_ssize_t
def _libc_pread_all(fd: int, ptr: int, length: int, offset: int) -> int:
    """pread() loop until all bytes read. ptr is a raw C pointer (int).

    Returns the number of bytes actually read; short reads only happen
    when EOF is hit before `length` bytes were available.
    """
    done = 0
    while done < length:
        got = _libc.pread(fd, ptr + done, length - done, offset + done)
        if got > 0:
            done += got
            continue
        if got == 0:
            # EOF reached before the requested range was exhausted.
            break
        errno = ctypes.get_errno()
        raise OSError(errno, f"pread failed: {os.strerror(errno)}")
    return done
def _libc_pwrite_all(fd: int, ptr: int, length: int, offset: int) -> int:
    """pwrite() loop until all bytes written. ptr is a raw C pointer (int).

    Returns the number of bytes actually written; raises OSError on a
    negative return from libc, stops early on a zero-byte write.
    """
    done = 0
    while done < length:
        put = _libc.pwrite(fd, ptr + done, length - done, offset + done)
        if put > 0:
            done += put
            continue
        if put == 0:
            break
        errno = ctypes.get_errno()
        raise OSError(errno, f"pwrite failed: {os.strerror(errno)}")
    return done
# ---------------------------------------------------------------------------
# Data classes
# ---------------------------------------------------------------------------
@dataclass
class BroadcastConfig:
    """Tunable settings for a single file broadcast."""

    # Bytes moved per NCCL broadcast call.
    chunk_size: int = 1024 * 1024 * 1024  # 1 GB per chunk
    # Staging-buffer pool size; deeper pools overlap more disk and network work.
    num_buffers: int = 10
    # Target device; resolved to the current CUDA device when left as None.
    device: Optional[torch.device] = None

    def __post_init__(self):
        if self.device is not None:
            return
        self.device = torch.device(f"cuda:{torch.cuda.current_device()}")
@dataclass
class FileEntry:
    """Metadata for a single file to broadcast."""

    rel_path: str  # path relative to the transfer root
    size: int  # file size in bytes
    num_chunks: int  # number of chunk_size pieces covering the file

    def chunk_ranges(self, chunk_size: int) -> list[tuple[int, int]]:
        """Return (offset, length) pairs covering the file, one per chunk.

        The final chunk is truncated to the remaining bytes.
        """
        return [
            (i * chunk_size, min(chunk_size, self.size - i * chunk_size))
            for i in range(self.num_chunks)
        ]
@dataclass
class BroadcastManifest:
    """Manifest describing all files to broadcast."""

    files: list[FileEntry] = field(default_factory=list)
    chunk_size: int = 256 * 1024 * 1024
    # True when source was a single file (not a directory)
    src_is_file: bool = False

    def to_json(self) -> str:
        """Serialize to a JSON string suitable for the wire."""
        payload = {
            "chunk_size": self.chunk_size,
            "src_is_file": self.src_is_file,
            "files": [
                {"rel_path": f.rel_path, "size": f.size, "num_chunks": f.num_chunks}
                for f in self.files
            ],
        }
        return json.dumps(payload)

    @staticmethod
    def from_json(s: str) -> "BroadcastManifest":
        """Reconstruct a manifest from a to_json() string."""
        d = json.loads(s)
        entries = [
            FileEntry(
                rel_path=f["rel_path"],
                size=f["size"],
                num_chunks=f["num_chunks"],
            )
            for f in d["files"]
        ]
        return BroadcastManifest(
            chunk_size=d["chunk_size"],
            # Older manifests may omit this flag; default to directory mode.
            src_is_file=d.get("src_is_file", False),
            files=entries,
        )
@dataclass
class ChunkDescriptor:
    """One chunk in the flattened broadcast pipeline."""
    file_idx: int  # index into manifest.files
    chunk_idx: int  # chunk index within the file
    num_chunks: int  # total chunks in this file
    rel_path: str  # for logging
    offset: int  # byte offset in the file
    length: int  # bytes in this chunk
@dataclass
class AllGatherTransfer:
    """One broadcast in the all_gather plan: a src_rank broadcasts files."""
    src_rank: int  # rank that owns and sends these files
    files: list[FileEntry] = field(default_factory=list)  # files this rank contributes
@dataclass
class AllGatherPlan:
    """Plan for all_gather: sequence of broadcasts, each from a different rank."""

    transfers: list[AllGatherTransfer] = field(default_factory=list)
    chunk_size: int = 1024 * 1024 * 1024

    def to_json(self) -> str:
        """Serialize the plan so rank 0 can broadcast it to every rank."""

        def _file_dict(f: FileEntry) -> dict:
            return {
                "rel_path": f.rel_path,
                "size": f.size,
                "num_chunks": f.num_chunks,
            }

        payload = {
            "chunk_size": self.chunk_size,
            "transfers": [
                {
                    "src_rank": t.src_rank,
                    "files": [_file_dict(f) for f in t.files],
                }
                for t in self.transfers
            ],
        }
        return json.dumps(payload)

    @staticmethod
    def from_json(s: str) -> "AllGatherPlan":
        """Reconstruct a plan from a to_json() string."""
        d = json.loads(s)
        transfers = []
        for t in d["transfers"]:
            entries = [
                FileEntry(
                    rel_path=f["rel_path"],
                    size=f["size"],
                    num_chunks=f["num_chunks"],
                )
                for f in t["files"]
            ]
            transfers.append(AllGatherTransfer(src_rank=t["src_rank"], files=entries))
        return AllGatherPlan(chunk_size=d["chunk_size"], transfers=transfers)
# ---------------------------------------------------------------------------
# Worker actor — one per node
# ---------------------------------------------------------------------------
@ray.remote(num_gpus=1)
class FileTransferWorker:
    """
    A single worker in the file transfer group.
    Each worker corresponds to one rank in the torch.distributed group.
    """

    def __init__(self):
        # Rank and world size are unset until init_process_group() is called.
        self.rank: Optional[int] = None
        self.world_size: Optional[int] = None
        self._initialized = False
        # Configure logging in this actor process so logger.info() is visible
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
            force=True,
        )
    def init_process_group(
        self,
        rank: int,
        world_size: int,
        master_addr: str,
        master_port: int,
    ):
        """Initialize torch.distributed (NCCL backend) for this worker.

        Args:
            rank: This worker's rank in the group.
            world_size: Total number of workers in the group.
            master_addr: Host of the rendezvous master (rank 0's node).
            master_port: TCP port of the rendezvous master.
        """
        self.rank = rank
        self.world_size = world_size
        # torch.distributed reads rendezvous info from these env vars.
        os.environ["MASTER_ADDR"] = master_addr
        os.environ["MASTER_PORT"] = str(master_port)
        os.environ["RANK"] = str(rank)
        os.environ["WORLD_SIZE"] = str(world_size)
        logger.info(
            f"[Rank {rank}] init_process_group: "
            f"master={master_addr}:{master_port} world_size={world_size}"
        )
        t0 = time.monotonic()
        dist.init_process_group(
            backend="nccl",
            rank=rank,
            world_size=world_size,
        )
        elapsed_ms = (time.monotonic() - t0) * 1000
        self._initialized = True
        logger.info(f"[Rank {rank}] init_process_group: done in {elapsed_ms:.0f}ms")
        # Warm up the PyTorch CUDA context now so it doesn't add latency
        # to the first broadcast. init_process_group uses NCCL's own CUDA
        # context, but PyTorch's context is lazy-initialized on first use.
        t0 = time.monotonic()
        _dummy = torch.zeros(1, device="cuda")
        del _dummy
        cuda_ms = (time.monotonic() - t0) * 1000
        logger.info(f"[Rank {rank}] init_process_group: cuda warmup={cuda_ms:.0f}ms")
def teardown(self):
if self._initialized:
logger.info(f"[Rank {self.rank}] teardown: destroying process group")
t0 = time.monotonic()
dist.destroy_process_group()
elapsed_ms = (time.monotonic() - t0) * 1000
self._initialized = False
logger.info(f"[Rank {self.rank}] teardown: done in {elapsed_ms:.0f}ms")
else:
logger.info(f"[Rank {self.rank}] teardown: not initialized, skipping")
    def get_node_ip(self) -> str:
        """Return the IP address of the node this actor is running on (via Ray)."""
        return ray.util.get_node_ip_address()
# -------------------------------------------------------------------
# Internal helpers
# -------------------------------------------------------------------
def _get_device(self, config: BroadcastConfig) -> torch.device:
if config.device is not None:
return config.device
return torch.device(f"cuda:{torch.cuda.current_device()}")
def _build_manifest(self, src_path: str, chunk_size: int) -> BroadcastManifest:
t0 = time.monotonic()
src = Path(src_path)
manifest = BroadcastManifest(chunk_size=chunk_size)
if src.is_file():
manifest.src_is_file = True
size = src.stat().st_size
manifest.files.append(
FileEntry(
rel_path=src.name,
size=size,
num_chunks=max(1, math.ceil(size / chunk_size)),
)
)
elif src.is_dir():
for fp in sorted(src.rglob("*")):
if fp.is_file():
size = fp.stat().st_size
manifest.files.append(
FileEntry(
rel_path=str(fp.relative_to(src)),
size=size,
num_chunks=max(1, math.ceil(size / chunk_size)),
)
)
else:
raise FileNotFoundError(f"Source path not found: {src_path}")
total_bytes = sum(f.size for f in manifest.files)
total_chunks = sum(f.num_chunks for f in manifest.files)
elapsed_ms = (time.monotonic() - t0) * 1000
logger.info(
f"[Rank {self.rank}] _build_manifest: {len(manifest.files)} files, "
f"{total_bytes / 1e9:.2f} GB, {total_chunks} total chunks, "
f"src_is_file={manifest.src_is_file}, took {elapsed_ms:.0f}ms"
)
return manifest
    def _broadcast_manifest(
        self,
        manifest: Optional[BroadcastManifest],
        src_rank: int,
        device: torch.device,
    ) -> BroadcastManifest:
        """Broadcast the manifest from src_rank to every rank.

        Two collectives: first a single int64 carrying the JSON byte length,
        then the UTF-8 JSON payload itself. Non-source ranks return the
        decoded manifest; the source returns its own manifest unchanged.
        All ranks must call this in the same order relative to other
        collectives or NCCL will deadlock.
        """
        logger.info(
            f"[Rank {self.rank}] _broadcast_manifest: begin "
            f"(is_src={self.rank == src_rank})"
        )
        t_total = time.monotonic()
        if self.rank == src_rank:
            data = manifest.to_json().encode("utf-8")
            size_t = torch.tensor([len(data)], dtype=torch.long, device=device)
            logger.info(
                f"[Rank {self.rank}] _broadcast_manifest: "
                f"manifest json={len(data)} bytes"
            )
        else:
            # Placeholder; overwritten by the size broadcast below.
            size_t = torch.tensor([0], dtype=torch.long, device=device)
        logger.info(
            f"[Rank {self.rank}] _broadcast_manifest: size_t={size_t}; {time.monotonic() - t_total:.1f}s"
        )
        dist.barrier()
        logger.info(
            f"[Rank {self.rank}] _broadcast_manifest: barrier done {time.monotonic() - t_total:.1f}s"
        )
        t0 = time.monotonic()
        dist.broadcast(size_t, src=src_rank)
        ms = (time.monotonic() - t0) * 1000
        manifest_size = size_t.item()
        logger.info(
            f"[Rank {self.rank}] _broadcast_manifest: "
            f"size_broadcast={ms:.1f}ms (manifest_size={manifest_size} bytes)"
        )
        if self.rank == src_rank:
            # bytearray() gives frombuffer the writable buffer it requires.
            payload = torch.frombuffer(bytearray(data), dtype=torch.uint8).to(device)
        else:
            payload = torch.empty(size_t.item(), dtype=torch.uint8, device=device)
        t0 = time.monotonic()
        dist.broadcast(payload, src=src_rank)
        ms = (time.monotonic() - t0) * 1000
        logger.info(
            f"[Rank {self.rank}] _broadcast_manifest: payload_broadcast={ms:.1f}ms"
        )
        if self.rank != src_rank:
            json_str = payload.cpu().numpy().tobytes().decode("utf-8")
            manifest = BroadcastManifest.from_json(json_str)
            logger.info(
                f"[Rank {self.rank}] _broadcast_manifest: "
                f"received {len(manifest.files)} files, "
                f"src_is_file={manifest.src_is_file}"
            )
        total_ms = (time.monotonic() - t_total) * 1000
        logger.info(f"[Rank {self.rank}] _broadcast_manifest: done in {total_ms:.0f}ms")
        return manifest
    def _broadcast_file_chunked(
        self,
        entry: FileEntry,
        src_path: Optional[Path],
        dst_path: Path,
        src_rank: int,
        config: BroadcastConfig,
    ):
        """Legacy single-file wrapper. Use _broadcast_all_chunks for pipelined transfers."""
        # Intentionally removed: the flattened chunk pipeline supersedes
        # per-file transfers. Kept as a stub so stale callers fail loudly.
        raise NotImplementedError("Use _broadcast_all_chunks instead")
# ------------------------------------------------------------------
# Chunk-based pipeline: flattened across all files
# ------------------------------------------------------------------
def _flatten_chunks(self, manifest: BroadcastManifest) -> list[ChunkDescriptor]:
"""Flatten all (file, chunk) pairs into a single ordered list."""
chunks: list[ChunkDescriptor] = []
for file_idx, entry in enumerate(manifest.files):
if entry.size == 0:
continue # empty files handled separately
for chunk_idx, (offset, length) in enumerate(
entry.chunk_ranges(manifest.chunk_size)
):
chunks.append(
ChunkDescriptor(
file_idx=file_idx,
chunk_idx=chunk_idx,
num_chunks=entry.num_chunks,
rel_path=entry.rel_path,
offset=offset,
length=length,
)
)
return chunks
    # Header: 8 bytes prepended to every broadcast payload
    #   [0:4] chunk_id  (int32 LE)
    #   [4:8] chunk_len (int32 LE)
    CHUNK_HEADER_SIZE = 8

    def _broadcast_all_chunks(
        self,
        manifest: BroadcastManifest,
        src_root: Optional[Path],
        dst_root: Path,
        src_rank: int,
        config: BroadcastConfig,
        pbar=None,
        bench_mode: Optional[str] = None,
    ):
        """
        Broadcast all files using a pool of worker threads, each owning one
        CPU pinned buffer + one GPU buffer. Chunks are sent in arbitrary
        order; a small header (chunk_id, chunk_len) is prepended so receivers
        can identify each chunk without prior knowledge of the send order.

        Sender:
            Workers: pick chunk from work queue, disk_read into cpu_buf
            (after header), H2D copy whole buffer, signal ready.
            Main: pop ready_q (unordered), NCCL broadcast, give slot back.
        Receiver:
            Main: grab free slot, NCCL broadcast into gpu_buf, parse header
            to learn chunk_id, hand (slot, chunk_id) to worker.
            Workers: D2H copy → pwrite to disk, return slot to free pool.

        All ranks broadcast the same fixed size (HEADER + max_chunk) so that
        NCCL stays in sync regardless of actual chunk lengths.
        """
        import queue
        import struct
        import threading

        HEADER = self.CHUNK_HEADER_SIZE
        is_src = self.rank == src_rank
        device = self._get_device(config)
        num_workers = config.num_buffers
        tag = "SND" if is_src else "RCV"
        flat_chunks = self._flatten_chunks(manifest)
        total_chunks = len(flat_chunks)
        total_bytes = sum(f.size for f in manifest.files)
        if total_chunks == 0:
            # Nothing to move (all files empty); keep ranks consistent and return.
            logger.info(f"[Rank {self.rank} {tag}] no chunks to transfer")
            if pbar is not None:
                pbar.update(0)
            return
        max_chunk = max(c.length for c in flat_chunks)
        bcast_size = HEADER + max_chunk  # fixed broadcast size for every call
        logger.info(
            f"[Rank {self.rank} {tag}] pipeline: {total_chunks} chunks across "
            f"{len(manifest.files)} files, {total_bytes / 1e9:.2f} GB, "
            f"max_chunk={max_chunk / 1e6:.0f} MB, bcast={bcast_size / 1e6:.0f} MB, "
            f"workers={num_workers}"
        )
        # ------------------------------------------------------------------
        # Allocate buffer pool: one (pinned CPU + GPU) pair per worker
        # Each buffer is HEADER + max_chunk bytes.
        # Each worker also gets its own CUDA stream for H2D / D2H copies.
        # ------------------------------------------------------------------
        t0 = time.monotonic()
        cpu_bufs = [
            torch.empty(bcast_size, dtype=torch.uint8, pin_memory=True)
            for _ in range(num_workers)
        ]
        gpu_bufs = [
            torch.empty(bcast_size, dtype=torch.uint8, device=device)
            for _ in range(num_workers)
        ]
        copy_streams = [torch.cuda.Stream(device=device) for _ in range(num_workers)]
        alloc_ms = (time.monotonic() - t0) * 1000
        logger.info(
            f"[Rank {self.rank} {tag}] allocated {num_workers}x{bcast_size / 1e6:.0f} MB "
            f"pinned + GPU in {alloc_ms:.0f}ms"
        )
        # NCCL warmup
        t0 = time.monotonic()
        dummy = torch.zeros(1, dtype=torch.uint8, device=device)
        dist.broadcast(dummy, src=src_rank)
        warmup_ms = (time.monotonic() - t0) * 1000
        logger.info(f"[Rank {self.rank} {tag}] nccl warmup={warmup_ms:.1f}ms")
        pipeline_t0 = time.monotonic()
        # ==================================================================
        # SENDER
        # ==================================================================
        if is_src:
            if bench_mode == "sender":
                # Sender bench: reuse one GPU buffer, only update header.
                # No disk read, no H2D, no worker threads.
                gpu_buf = gpu_bufs[0]
                header_np = cpu_bufs[0].numpy()
                for i in range(total_chunks):
                    iter_t0 = time.monotonic()
                    desc = flat_chunks[i]
                    chunk_len = desc.length
                    # Stamp header into GPU via tiny CPU→GPU
                    struct.pack_into("<II", header_np, 0, i, chunk_len)
                    gpu_buf[:HEADER].copy_(cpu_bufs[0][:HEADER])
                    t0 = time.monotonic()
                    dist.broadcast(gpu_buf, src=src_rank)
                    nccl_ms = (time.monotonic() - t0) * 1000
                    elapsed = (time.monotonic() - pipeline_t0) * 1000
                    logger.info(
                        f"[Rank {self.rank} {tag}] [{i}/{total_chunks}] "
                        f"{desc.rel_path} c{desc.chunk_idx}/{desc.num_chunks} "
                        f"slot=0 chunk_id={i} BENCH "
                        f"nccl={nccl_ms:.1f}ms "
                        f"iter={(time.monotonic() - iter_t0) * 1000:.1f}ms "
                        f"T+{elapsed:.0f}ms ({chunk_len / 1e6:.1f} MB)"
                    )
                    if pbar is not None:
                        pbar.update(chunk_len)
            else:
                import os as _os_snd
                # ready_q: workers put (slot, chunk_index) after read+h2d
                ready_q: queue.Queue = queue.Queue()
                # work_q: main thread puts chunk_index for workers
                work_q: queue.Queue = queue.Queue()
                worker_error = None
                # Open file descriptors for pread (no seek, no lock needed)
                file_fds_snd: dict[int, int] = {}
                for file_idx, entry in enumerate(manifest.files):
                    if entry.size == 0:
                        continue
                    file_fds_snd[file_idx] = _os_snd.open(
                        str(src_root / entry.rel_path), _os_snd.O_RDONLY
                    )

                def _sender_worker(worker_id):
                    """Each worker owns slot == worker_id."""
                    nonlocal worker_error
                    slot = worker_id
                    try:
                        while True:
                            item = work_q.get()
                            if item is None:
                                # Sentinel: shut down this worker.
                                return
                            chunk_i = item
                            desc = flat_chunks[chunk_i]
                            # Write header into cpu_buf[slot][0:HEADER]
                            struct.pack_into(
                                "<II", cpu_bufs[slot].numpy(), 0, chunk_i, desc.length
                            )
                            # Disk read via libc pread directly into pinned buf
                            fd = file_fds_snd[desc.file_idx]
                            buf_ptr = cpu_bufs[slot].data_ptr() + HEADER
                            t_read = time.monotonic()
                            _libc_pread_all(fd, buf_ptr, desc.length, desc.offset)
                            read_ms = (time.monotonic() - t_read) * 1000
                            # H2D copy on worker's own CUDA stream
                            t_h2d = time.monotonic()
                            with torch.cuda.stream(copy_streams[slot]):
                                gpu_bufs[slot].copy_(cpu_bufs[slot])
                            copy_streams[slot].synchronize()
                            h2d_ms = (time.monotonic() - t_h2d) * 1000
                            logger.info(
                                f"[Rank {self.rank} {tag}] PREP [{chunk_i}/{total_chunks}] "
                                f"{desc.rel_path} c{desc.chunk_idx}/{desc.num_chunks} "
                                f"slot={slot} read={read_ms:.0f}ms h2d={h2d_ms:.0f}ms "
                                f"({desc.length / 1e6:.1f} MB)"
                            )
                            ready_q.put((slot, chunk_i))
                    except Exception as e:
                        worker_error = e
                        # None unblocks the main loop so it can re-raise.
                        ready_q.put(None)

                # Start worker threads
                workers = []
                for wid in range(num_workers):
                    t = threading.Thread(
                        target=_sender_worker, args=(wid,), daemon=True
                    )
                    t.start()
                    workers.append(t)
                # Seed initial work: one chunk per worker (up to total_chunks)
                next_chunk = 0
                for _ in range(min(num_workers, total_chunks)):
                    work_q.put(next_chunk)
                    next_chunk += 1
                # Main loop: pop ready (unordered), NCCL broadcast, give slot back
                for i in range(total_chunks):
                    iter_t0 = time.monotonic()
                    item = ready_q.get()
                    if item is None:
                        if worker_error is not None:
                            raise worker_error
                        break
                    slot, chunk_i = item
                    desc = flat_chunks[chunk_i]
                    chunk_len = desc.length
                    prep_wait_ms = (time.monotonic() - iter_t0) * 1000
                    # NCCL broadcast (fixed size — all ranks use bcast_size)
                    t0 = time.monotonic()
                    dist.broadcast(gpu_bufs[slot], src=src_rank)
                    nccl_ms = (time.monotonic() - t0) * 1000
                    # Dispatch next chunk to this worker's slot
                    if next_chunk < total_chunks:
                        work_q.put(next_chunk)
                        next_chunk += 1
                    elapsed = (time.monotonic() - pipeline_t0) * 1000
                    logger.info(
                        f"[Rank {self.rank} {tag}] [{i}/{total_chunks}] "
                        f"{desc.rel_path} c{desc.chunk_idx}/{desc.num_chunks} "
                        f"slot={slot} chunk_id={chunk_i} prep_wait={prep_wait_ms:.1f}ms "
                        f"nccl={nccl_ms:.1f}ms "
                        f"iter={(time.monotonic() - iter_t0) * 1000:.1f}ms "
                        f"T+{elapsed:.0f}ms ({chunk_len / 1e6:.1f} MB)"
                    )
                    if pbar is not None:
                        pbar.update(chunk_len)
                # Shut down workers
                for _ in workers:
                    work_q.put(None)
                for t in workers:
                    t.join()
                # Close file descriptors
                for fd in file_fds_snd.values():
                    _os_snd.close(fd)
        # ==================================================================
        # RECEIVER
        # ==================================================================
        else:
            if bench_mode == "receiver":
                # Bench mode: NCCL recv only, no D2H / disk write.
                # Reuse a single GPU buffer (slot 0) for all receives.
                for i in range(total_chunks):
                    iter_t0 = time.monotonic()
                    t0 = time.monotonic()
                    dist.broadcast(gpu_bufs[0], src=src_rank)
                    nccl_ms = (time.monotonic() - t0) * 1000
                    # Parse header to report progress
                    header_cpu = gpu_bufs[0][:HEADER].cpu().numpy().tobytes()
                    chunk_id, chunk_len = struct.unpack_from("<II", header_cpu, 0)
                    desc = flat_chunks[chunk_id]
                    elapsed = (time.monotonic() - pipeline_t0) * 1000
                    logger.info(
                        f"[Rank {self.rank} {tag}] [{i}/{total_chunks}] "
                        f"{desc.rel_path} c{desc.chunk_idx}/{desc.num_chunks} "
                        f"slot=0 chunk_id={chunk_id} BENCH "
                        f"nccl={nccl_ms:.1f}ms "
                        f"iter={(time.monotonic() - iter_t0) * 1000:.1f}ms "
                        f"T+{elapsed:.0f}ms ({chunk_len / 1e6:.1f} MB)"
                    )
                    if pbar is not None:
                        pbar.update(chunk_len)
            else:
                import os as _os
                # Pre-allocate output files
                file_fds: dict[int, int] = {}
                t0 = time.monotonic()
                for entry_idx, entry in enumerate(manifest.files):
                    if entry.size == 0:
                        continue
                    dst_path = dst_root / entry.rel_path
                    dst_path.parent.mkdir(parents=True, exist_ok=True)
                    fd = _os.open(
                        str(dst_path), _os.O_CREAT | _os.O_WRONLY | _os.O_TRUNC
                    )
                    # Reserve full size up front so pwrite never extends the file.
                    _os.ftruncate(fd, entry.size)
                    file_fds[entry_idx] = fd
                prealloc_ms = (time.monotonic() - t0) * 1000
                logger.info(
                    f"[Rank {self.rank} {tag}] pre-allocated {len(file_fds)} files "
                    f"in {prealloc_ms:.0f}ms"
                )
                # free_slots: slots available for main thread to recv into
                free_slots: queue.Queue[int] = queue.Queue()
                for s in range(num_workers):
                    free_slots.put(s)
                # write_q: main thread puts (slot, chunk_id) for workers
                write_q: queue.Queue = queue.Queue()
                worker_error = None

                def _recv_worker(worker_id):
                    """Worker does D2H + pwrite, then returns slot."""
                    nonlocal worker_error
                    try:
                        while True:
                            item = write_q.get()
                            if item is None:
                                # Sentinel: shut down this worker.
                                return
                            slot, chunk_id, chunk_len = item
                            desc = flat_chunks[chunk_id]
                            # D2H copy on worker's own CUDA stream
                            payload_end = HEADER + chunk_len
                            t_d2h = time.monotonic()
                            with torch.cuda.stream(copy_streams[slot]):
                                cpu_bufs[slot][:payload_end].copy_(
                                    gpu_bufs[slot][:payload_end]
                                )
                            copy_streams[slot].synchronize()
                            d2h_ms = (time.monotonic() - t_d2h) * 1000
                            # pwrite to disk via libc (skip header)
                            fd = file_fds[desc.file_idx]
                            buf_ptr = cpu_bufs[slot].data_ptr() + HEADER
                            t_w = time.monotonic()
                            _libc_pwrite_all(fd, buf_ptr, chunk_len, desc.offset)
                            write_ms = (time.monotonic() - t_w) * 1000
                            logger.info(
                                f"[Rank {self.rank} {tag}] WRITE [{chunk_id}/{total_chunks}] "
                                f"{desc.rel_path} c{desc.chunk_idx}/{desc.num_chunks} "
                                f"slot={slot} d2h={d2h_ms:.0f}ms write={write_ms:.0f}ms "
                                f"({chunk_len / 1e6:.1f} MB)"
                            )
                            # Return slot to main thread
                            free_slots.put(slot)
                    except Exception as e:
                        worker_error = e
                        free_slots.put(-1)  # unblock main

                # Start worker threads
                workers = []
                for wid in range(num_workers):
                    t = threading.Thread(target=_recv_worker, args=(wid,), daemon=True)
                    t.start()
                    workers.append(t)
                # Main loop: grab free slot, NCCL recv, parse header, hand to worker
                for i in range(total_chunks):
                    iter_t0 = time.monotonic()
                    t0 = time.monotonic()
                    slot = free_slots.get()
                    if slot == -1 and worker_error is not None:
                        raise worker_error
                    slot_wait_ms = (time.monotonic() - t0) * 1000
                    # NCCL receive (fixed size)
                    t0 = time.monotonic()
                    dist.broadcast(gpu_bufs[slot], src=src_rank)
                    nccl_ms = (time.monotonic() - t0) * 1000
                    # Parse header on GPU — copy just 8 bytes to CPU
                    header_cpu = gpu_bufs[slot][:HEADER].cpu().numpy().tobytes()
                    chunk_id, chunk_len = struct.unpack_from("<II", header_cpu, 0)
                    desc = flat_chunks[chunk_id]
                    # Hand off to worker for D2H + write
                    write_q.put((slot, chunk_id, chunk_len))
                    elapsed = (time.monotonic() - pipeline_t0) * 1000
                    logger.info(
                        f"[Rank {self.rank} {tag}] [{i}/{total_chunks}] "
                        f"{desc.rel_path} c{desc.chunk_idx}/{desc.num_chunks} "
                        f"slot={slot} chunk_id={chunk_id} slot_wait={slot_wait_ms:.1f}ms "
                        f"nccl={nccl_ms:.1f}ms "
                        f"iter={(time.monotonic() - iter_t0) * 1000:.1f}ms "
                        f"T+{elapsed:.0f}ms ({chunk_len / 1e6:.1f} MB)"
                    )
                    if pbar is not None:
                        pbar.update(chunk_len)
                # Shut down workers and wait
                for _ in workers:
                    write_q.put(None)
                t0 = time.monotonic()
                for t in workers:
                    t.join()
                if worker_error is not None:
                    raise worker_error
                # Close file descriptors
                for fd in file_fds.values():
                    _os.close(fd)
                flush_ms = (time.monotonic() - t0) * 1000
                tp = total_bytes / (flush_ms / 1000) / 1e9 if flush_ms > 0 else 0
                logger.info(
                    f"[Rank {self.rank} {tag}] flush wait: {flush_ms:.0f}ms ({tp:.1f} GB/s)"
                )
        torch.cuda.synchronize(device)
        pipeline_ms = (time.monotonic() - pipeline_t0) * 1000
        tp = total_bytes / (pipeline_ms / 1000) / 1e9 if pipeline_ms > 0 else 0
        logger.info(
            f"[Rank {self.rank} {tag}] pipeline done: {total_chunks} chunks in "
            f"{pipeline_ms:.0f}ms ({tp:.1f} GB/s)"
        )
    # -------------------------------------------------------------------
    # Public API
    # -------------------------------------------------------------------
    def broadcast(
        self,
        src_path: str,
        dst_dir: Optional[str] = None,
        src_rank: int = 0,
        chunk_size: int = 256 * 1024 * 1024,
        num_buffers: int = 10,
        bench_mode: Optional[str] = None,
    ):
        """
        Broadcast file(s) from src_rank to all other ranks.

        Args:
            src_path: Path to file or directory on source rank.
            dst_dir: Destination directory on all ranks.
                Defaults to same path as src_path.
            src_rank: Rank that owns the source files.
            chunk_size: Bytes per chunk.
            num_buffers: Buffer pool size (controls overlap depth).
            bench_mode: "sender" to skip disk reads (send junk data),
                "receiver" to skip D2H + disk writes.
                Both measure pure NCCL throughput from respective side.
        """
        assert self._initialized, "Call init_process_group first"
        broadcast_t0 = time.monotonic()
        config = BroadcastConfig(chunk_size=chunk_size, num_buffers=num_buffers)
        device = self._get_device(config)
        is_src = self.rank == src_rank
        logger.info(
            f"[Rank {self.rank}] broadcast: begin src_path={src_path} "
            f"dst_dir={dst_dir} is_src={is_src} "
            f"chunk_size={chunk_size / 1e6:.0f}MB num_buffers={num_buffers}"
        )
        # Phase 1: manifest
        if is_src:
            manifest = self._build_manifest(src_path, config.chunk_size)
            total_bytes = sum(f.size for f in manifest.files)
            logger.info(
                f"[Rank {self.rank}] broadcast: built manifest, "
                f"{len(manifest.files)} file(s), {total_bytes / 1e9:.2f} GB"
            )
        else:
            manifest = None
        manifest = self._broadcast_manifest(manifest, src_rank, device)
        # Phase 2: compute dst_root
        # For single files, dst_root is the parent directory so that
        # dst_root / rel_path recreates the file at the right location.
        # For directories, dst_root is the directory itself.
        if dst_dir is not None:
            dst_root = Path(dst_dir)
        elif manifest.src_is_file:
            dst_root = Path(src_path).parent
        else:
            dst_root = Path(src_path)
        # src_root: for files, use parent so src_root / rel_path = original path
        src_root = Path(src_path)
        if is_src and manifest.src_is_file:
            src_root = src_root.parent
        logger.info(
            f"[Rank {self.rank}] broadcast: src_root={src_root} dst_root={dst_root} "
            f"transferring {len(manifest.files)} file(s)"
        )
        # Handle empty files (0 bytes) that need to be created on receivers
        for entry in manifest.files:
            if entry.size == 0:
                if not is_src:
                    dst_file = dst_root / entry.rel_path
                    dst_file.parent.mkdir(parents=True, exist_ok=True)
                    dst_file.touch()
                logger.info(
                    f"[Rank {self.rank}] broadcast: "
                    f"{'skipped' if is_src else 'created'} empty file {entry.rel_path}"
                )
        # Chunk-based pipeline for all non-empty files
        total_bytes = sum(f.size for f in manifest.files)
        pbar = ray_tqdm(
            total=total_bytes,
            desc=f"[Rank {self.rank}] broadcast",
            unit="B",
        )
        self._broadcast_all_chunks(
            manifest=manifest,
            src_root=src_root if is_src else None,
            dst_root=dst_root,
            src_rank=src_rank,
            config=config,
            pbar=pbar,
            bench_mode=bench_mode,
        )
        pbar.close()
        t0 = time.monotonic()
        dist.barrier()
        barrier_ms = (time.monotonic() - t0) * 1000
        total_ms = (time.monotonic() - broadcast_t0) * 1000
        throughput = total_bytes / (total_ms / 1000) / 1e9 if total_ms > 0 else 0
        logger.info(
            f"[Rank {self.rank}] broadcast: complete "
            f"total={total_ms:.0f}ms barrier={barrier_ms:.0f}ms "
            f"throughput={throughput:.2f} GB/s"
        )
    def all_gather(
        self,
        src_dir: str,
        dst_dir: Optional[str] = None,
        chunk_size: int = 256 * 1024 * 1024,
        num_buffers: int = 10,
        bench_mode: Optional[str] = None,
    ):
        """
        All-gather files: each rank contributes files from src_dir,
        and at the end every rank has all files in dst_dir.

        If src_dir == dst_dir (or dst_dir is None), each rank already has
        its own files and only receives files from other ranks.

        Args:
            src_dir: Directory containing this rank's files.
            dst_dir: Destination directory. Defaults to src_dir.
            chunk_size: Bytes per NCCL chunk.
            num_buffers: Buffer pool size.
            bench_mode: "sender" skips disk reads, "receiver" skips
                D2H copies and disk writes (for throughput benchmarking);
                passed through to _broadcast_all_chunks.
        """
        assert self._initialized, "Call init_process_group first"
        all_gather_t0 = time.monotonic()
        config = BroadcastConfig(chunk_size=chunk_size, num_buffers=num_buffers)
        device = self._get_device(config)
        if dst_dir is None:
            dst_dir = src_dir
        src_root = Path(src_dir)
        dst_root = Path(dst_dir)
        logger.info(
            f"[Rank {self.rank}] all_gather: begin src_dir={src_dir} "
            f"dst_dir={dst_dir} chunk_size={chunk_size / 1e6:.0f}MB "
            f"num_buffers={num_buffers}"
        )
        # Phase 1: each rank builds local manifest
        t0 = time.monotonic()
        if src_root.is_dir():
            local_manifest = self._build_manifest(src_dir, config.chunk_size)
        else:
            # Directory doesn't exist or is empty — empty manifest
            local_manifest = BroadcastManifest(chunk_size=config.chunk_size)
            local_manifest.src_is_file = False
        logger.info(
            f"[Rank {self.rank}] all_gather: local manifest "
            f"{len(local_manifest.files)} files, "
            f"{sum(f.size for f in local_manifest.files) / 1e9:.2f} GB "
            f"in {(time.monotonic() - t0) * 1000:.0f}ms"
        )
        # Phase 2: gather all manifests to rank 0
        all_manifests = self._gather_manifests(local_manifest, device)
        # Phase 3: rank 0 builds plan, broadcasts to all
        plan = self._build_and_broadcast_plan(all_manifests, config.chunk_size, device)
        logger.info(
            f"[Rank {self.rank}] all_gather: plan has "
            f"{len(plan.transfers)} transfers, "
            f"{sum(len(t.files) for t in plan.transfers)} total files"
        )
        # Phase 4: execute each transfer (broadcast from src_rank)
        total_bytes_all = sum(f.size for t in plan.transfers for f in t.files)
        pbar = ray_tqdm(
            total=total_bytes_all,
            desc=f"[Rank {self.rank}] all_gather {self.rank}",
            unit="B",
        )
        total_transferred = 0
        for ti, transfer in enumerate(plan.transfers):
            if not transfer.files:
                continue
            is_src = self.rank == transfer.src_rank
            manifest = BroadcastManifest(
                chunk_size=plan.chunk_size,
                src_is_file=False,
                files=transfer.files,
            )
            transfer_bytes = sum(f.size for f in transfer.files)
            logger.info(
                f"[Rank {self.rank}] all_gather: transfer {ti}/{len(plan.transfers)} "
                f"from rank {transfer.src_rank}, "
                f"{len(transfer.files)} files, {transfer_bytes / 1e9:.2f} GB"
            )
            # Handle empty files
            for entry in manifest.files:
                if entry.size == 0 and not is_src:
                    dst_file = dst_root / entry.rel_path
                    dst_file.parent.mkdir(parents=True, exist_ok=True)
                    dst_file.touch()
            # Transfer non-empty files
            self._broadcast_all_chunks(
                manifest=manifest,
                src_root=src_root if is_src else None,
                dst_root=dst_root,
                src_rank=transfer.src_rank,
                config=config,
                pbar=pbar,
                bench_mode=bench_mode,
            )
            total_transferred += transfer_bytes
        pbar.close()
        t0 = time.monotonic()
        dist.barrier()
        barrier_ms = (time.monotonic() - t0) * 1000
        total_ms = (time.monotonic() - all_gather_t0) * 1000
        throughput = total_transferred / (total_ms / 1000) / 1e9 if total_ms > 0 else 0
        logger.info(
            f"[Rank {self.rank}] all_gather: complete "
            f"total={total_ms:.0f}ms barrier={barrier_ms:.0f}ms "
            f"transferred={total_transferred / 1e9:.2f} GB "
            f"throughput={throughput:.2f} GB/s"
        )
    def _gather_manifests(
        self,
        local_manifest: BroadcastManifest,
        device: torch.device,
    ) -> Optional[list[BroadcastManifest]]:
        """
        Gather all ranks' manifests to rank 0.
        Returns list of manifests on rank 0, None on other ranks.

        Two all_gathers: first the per-rank JSON sizes, then payloads
        zero-padded to the maximum size so all tensors match in shape.
        """
        t0 = time.monotonic()
        # Serialize local manifest
        local_json = local_manifest.to_json().encode("utf-8")
        local_size = len(local_json)
        # All-gather sizes so each rank knows how much data to expect
        size_tensor = torch.tensor([local_size], dtype=torch.long, device=device)
        size_list = [
            torch.zeros(1, dtype=torch.long, device=device)
            for _ in range(self.world_size)
        ]
        dist.all_gather(size_list, size_tensor)
        sizes = [int(s.item()) for s in size_list]
        max_size = max(sizes)
        logger.info(
            f"[Rank {self.rank}] _gather_manifests: sizes={sizes} max_size={max_size}"
        )
        # Pad local payload to max_size and all-gather
        local_payload = torch.zeros(max_size, dtype=torch.uint8, device=device)
        local_bytes = torch.frombuffer(bytearray(local_json), dtype=torch.uint8)
        local_payload[:local_size] = local_bytes.to(device)
        payload_list = [
            torch.zeros(max_size, dtype=torch.uint8, device=device)
            for _ in range(self.world_size)
        ]
        dist.all_gather(payload_list, local_payload)
        gather_ms = (time.monotonic() - t0) * 1000
        logger.info(f"[Rank {self.rank}] _gather_manifests: done in {gather_ms:.0f}ms")
        # Only rank 0 needs to decode all manifests
        if self.rank == 0:
            manifests = []
            for r in range(self.world_size):
                # Slice off the zero padding using the gathered true size.
                json_bytes = payload_list[r][: sizes[r]].cpu().numpy().tobytes()
                manifests.append(
                    BroadcastManifest.from_json(json_bytes.decode("utf-8"))
                )
            return manifests
        return None
def _build_and_broadcast_plan(
    self,
    all_manifests: Optional[list[BroadcastManifest]],
    chunk_size: int,
    device: torch.device,
) -> AllGatherPlan:
    """
    Rank 0 builds the all_gather plan and broadcasts it to all ranks.

    Args:
        all_manifests: Per-rank manifests (required on rank 0; None on
            every other rank).
        chunk_size: Bytes per NCCL chunk, recorded into the plan.
        device: GPU device used for the broadcast collectives.

    Returns:
        The same AllGatherPlan on every rank.
    """
    if self.rank == 0:
        assert all_manifests is not None
        plan = AllGatherPlan(chunk_size=chunk_size)
        # Deduplicate: the first rank (lowest index) holding a file wins.
        # A set suffices here — the original mapped rel_path -> src_rank,
        # but the rank value was never read.
        seen_files: set[str] = set()
        per_rank_files: dict[int, list[FileEntry]] = {
            r: [] for r in range(self.world_size)
        }
        for rank, manifest in enumerate(all_manifests):
            for entry in manifest.files:
                if entry.rel_path not in seen_files:
                    seen_files.add(entry.rel_path)
                    per_rank_files[rank].append(entry)
        # Build transfers — one per rank that contributes at least one
        # unique file.
        for rank in range(self.world_size):
            files = per_rank_files[rank]
            if files:
                plan.transfers.append(AllGatherTransfer(src_rank=rank, files=files))
        total_files = sum(len(t.files) for t in plan.transfers)
        total_bytes = sum(f.size for t in plan.transfers for f in t.files)
        # Log tag fixed: previously said "_build_all_gather_plan", a method
        # name that does not exist.
        logger.info(
            f"[Rank 0] _build_and_broadcast_plan: "
            f"{total_files} unique files from "
            f"{len(plan.transfers)} ranks, "
            f"{total_bytes / 1e9:.2f} GB total"
        )
        plan_json = plan.to_json().encode("utf-8")
        plan_size = torch.tensor([len(plan_json)], dtype=torch.long, device=device)
    else:
        plan_json = b""
        plan_size = torch.tensor([0], dtype=torch.long, device=device)
    # Broadcast the serialized plan: size first so receivers can allocate,
    # then the payload itself.
    dist.broadcast(plan_size, src=0)
    size = int(plan_size.item())
    if self.rank == 0:
        payload = torch.frombuffer(bytearray(plan_json), dtype=torch.uint8).to(
            device
        )
    else:
        payload = torch.empty(size, dtype=torch.uint8, device=device)
    dist.broadcast(payload, src=0)
    if self.rank != 0:
        json_str = payload.cpu().numpy().tobytes().decode("utf-8")
        plan = AllGatherPlan.from_json(json_str)
    logger.info(
        f"[Rank {self.rank}] _build_and_broadcast_plan: done, "
        f"{len(plan.transfers)} transfers"
    )
    return plan
# ---------------------------------------------------------------------------
# Coordinator actor — manages the worker group
# ---------------------------------------------------------------------------
@ray.remote
class FileTransferGroup:
    """
    Manages a group of FileTransferWorker actors across the Ray cluster.

    Automatically spawns one worker per GPU node using node affinity, so
    every node participates exactly once in the torch.distributed group.

    Usage:
        group = FileTransferGroup.remote()
        ray.get(group.setup.remote())
        ray.get(group.broadcast.remote(src_path="/data/weights/"))
        ray.get(group.teardown.remote())
    """

    def __init__(self):
        # Worker actor handles, one per GPU node; populated by setup().
        self.workers: list = []
        # Size of the torch.distributed world (== number of GPU nodes).
        self.world_size: int = 0
        # force=True replaces any handlers already installed in this Ray
        # actor process so our format takes effect.
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
            force=True,
        )

    def setup(
        self,
        master_port: Optional[int] = None,
        worker_options: Optional[dict] = None,
        src_node_ip: Optional[str] = None,
    ):
        """
        Create one worker per node and initialize torch.distributed.

        Automatically discovers all GPU nodes in the Ray cluster and pins
        one worker to each node via NodeAffinitySchedulingStrategy.

        Args:
            master_port: Port for torch.distributed rendezvous. A random
                free port is allocated when None.
            worker_options: Extra kwargs for FileTransferWorker.options(),
                e.g. {"num_cpus": 4}.
            src_node_ip: If provided, the worker on this node gets rank 0.
                Defaults to driver node IP in broadcast_files().

        Raises:
            RuntimeError: If no alive GPU node exists in the cluster.
            ValueError: If src_node_ip does not match any GPU node.
        """
        from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
        setup_t0 = time.monotonic()
        logger.info(
            f"[Group] setup: begin master_port={master_port} src_node_ip={src_node_ip}"
        )
        if worker_options is None:
            worker_options = {}
        if master_port is None:
            master_port = get_random_free_port()
            logger.info(f"[Group] setup: allocated random port {master_port}")
        # Discover all live GPU nodes.
        nodes = [
            n for n in ray.nodes() if n["Alive"] and n["Resources"].get("GPU", 0) > 0
        ]
        if not nodes:
            raise RuntimeError("No alive GPU nodes found in the Ray cluster")
        self.world_size = len(nodes)
        node_ips = [n["NodeManagerAddress"] for n in nodes]
        logger.info(f"[Group] setup: found {self.world_size} GPU nodes: {node_ips}")
        # Sort so src_node gets index 0: False sorts before True, so the
        # matching node (if any) moves to the front of the list.
        if src_node_ip is not None:
            nodes.sort(key=lambda n: (n["NodeManagerAddress"] != src_node_ip))
            # If the front node still doesn't match, no node had that IP.
            if nodes[0]["NodeManagerAddress"] != src_node_ip:
                available = {n["NodeManagerAddress"] for n in nodes}
                raise ValueError(
                    f"No GPU node found at {src_node_ip}. Available: {available}"
                )
        sorted_ips = [n["NodeManagerAddress"] for n in nodes]
        logger.info(
            f"[Group] setup: rank assignment: {sorted_ips} (rank 0 = {sorted_ips[0]})"
        )
        # Spawn one worker per node, pinned via node affinity.
        t0 = time.monotonic()
        self.workers = []
        for node in nodes:
            # soft=False: fail scheduling rather than fall back to a
            # different node — each node must host exactly one worker.
            strategy = NodeAffinitySchedulingStrategy(
                node_id=node["NodeID"],
                soft=False,
            )
            worker = FileTransferWorker.options(
                scheduling_strategy=strategy,
                **worker_options,
            ).remote()
            self.workers.append(worker)
        spawn_ms = (time.monotonic() - t0) * 1000
        logger.info(
            f"[Group] setup: spawned {len(self.workers)} workers in {spawn_ms:.0f}ms"
        )
        # Use rank 0's IP as master for the rendezvous.
        master_addr = nodes[0]["NodeManagerAddress"]
        # Initialize torch.distributed on all workers in parallel; ranks
        # follow the sorted node order so rank 0 is the src node.
        logger.info(
            f"[Group] setup: init_process_group on all workers "
            f"(master={master_addr}:{master_port})"
        )
        t0 = time.monotonic()
        init_futures = [
            w.init_process_group.remote(
                rank=i,
                world_size=self.world_size,
                master_addr=master_addr,
                master_port=master_port,
            )
            for i, w in enumerate(self.workers)
        ]
        ray.get(init_futures)
        init_ms = (time.monotonic() - t0) * 1000
        total_ms = (time.monotonic() - setup_t0) * 1000
        logger.info(
            f"[Group] setup: done in {total_ms:.0f}ms "
            f"(spawn={spawn_ms:.0f}ms init={init_ms:.0f}ms) "
            f"{self.world_size} workers on {sorted_ips}"
        )

    def broadcast(
        self,
        src_path: str,
        dst_dir: Optional[str] = None,
        chunk_size: int = 256 * 1024 * 1024,
        num_buffers: int = 10,
        bench_mode: Optional[str] = None,
    ):
        """
        Broadcast file(s) from rank 0 to all ranks.

        All workers execute in parallel; returns when all finish.

        Args:
            src_path: File or directory to broadcast (exists on rank 0).
            dst_dir: Destination directory. Defaults to same path as
                src_path.
            chunk_size: Bytes per NCCL chunk.
            num_buffers: Buffer pool size passed to each worker.
            bench_mode: Optional benchmark mode forwarded to workers.
        """
        logger.info(
            f"[Group] broadcast: begin src_path={src_path} dst_dir={dst_dir} "
            f"chunk_size={chunk_size / 1e6:.0f}MB num_buffers={num_buffers} "
            f"bench_mode={bench_mode} workers={len(self.workers)}"
        )
        t0 = time.monotonic()
        # Fan out to every worker; the collective inside requires all of
        # them to participate, so we block on the full set.
        futures = [
            w.broadcast.remote(
                src_path=src_path,
                dst_dir=dst_dir,
                src_rank=0,
                chunk_size=chunk_size,
                num_buffers=num_buffers,
                bench_mode=bench_mode,
            )
            for w in self.workers
        ]
        ray.get(futures)
        elapsed_ms = (time.monotonic() - t0) * 1000
        logger.info(f"[Group] broadcast: done in {elapsed_ms:.0f}ms")

    def all_gather(
        self,
        src_dir: str,
        dst_dir: Optional[str] = None,
        chunk_size: int = 256 * 1024 * 1024,
        num_buffers: int = 10,
        bench_mode: Optional[str] = None,
    ):
        """
        All-gather files from all ranks.

        Each rank contributes files from src_dir. At the end every rank
        has all files in dst_dir.

        Args:
            src_dir: Directory containing each rank's local files.
            dst_dir: Destination directory. Defaults to src_dir.
            chunk_size: Bytes per NCCL chunk.
            num_buffers: Buffer pool size passed to each worker.
            bench_mode: Optional benchmark mode forwarded to workers.
        """
        logger.info(
            f"[Group] all_gather: begin src_dir={src_dir} dst_dir={dst_dir} "
            f"chunk_size={chunk_size / 1e6:.0f}MB num_buffers={num_buffers} "
            f"bench_mode={bench_mode} workers={len(self.workers)}"
        )
        t0 = time.monotonic()
        # All workers must join the collective; block on the full set.
        futures = [
            w.all_gather.remote(
                src_dir=src_dir,
                dst_dir=dst_dir,
                chunk_size=chunk_size,
                num_buffers=num_buffers,
                bench_mode=bench_mode,
            )
            for w in self.workers
        ]
        ray.get(futures)
        elapsed_ms = (time.monotonic() - t0) * 1000
        logger.info(f"[Group] all_gather: done in {elapsed_ms:.0f}ms")

    def teardown(self):
        """Destroy process groups and kill workers.

        Waits for each worker to tear down its process group before
        killing the actors, then clears the handle list so the group can
        be garbage-collected cleanly.
        """
        logger.info(f"[Group] teardown: begin ({len(self.workers)} workers)")
        t0 = time.monotonic()
        if self.workers:
            ray.get([w.teardown.remote() for w in self.workers])
            for w in self.workers:
                ray.kill(w)
            self.workers = []
        elapsed_ms = (time.monotonic() - t0) * 1000
        logger.info(f"[Group] teardown: done in {elapsed_ms:.0f}ms")
# ---------------------------------------------------------------------------
# Convenience function
# ---------------------------------------------------------------------------
def broadcast_files(
    src_path: str,
    dst_dir: Optional[str] = None,
    chunk_size: int = 1024 * 1024 * 1024,
    num_buffers: int = 10,
    master_port: Optional[int] = None,
    worker_options: Optional[dict] = None,
    bench_mode: Optional[str] = None,
) -> FileTransferGroup:
    """
    One-shot convenience: create a group, broadcast, return the group.

    Call this on the node that has the source files; that node becomes
    rank 0 automatically. One worker is created per GPU node in the
    cluster. The returned group can be reused for further operations.

    Args:
        src_path: File or directory to broadcast.
        dst_dir: Destination directory on all ranks. Defaults to src_path.
        chunk_size: Bytes per NCCL chunk.
        num_buffers: GPU double/triple buffering.
        master_port: torch.distributed rendezvous port.
        worker_options: Extra ray.remote options for workers.
        bench_mode: Optional benchmark mode forwarded to workers.

    Returns:
        The FileTransferGroup actor (reusable for more operations).

    Example:
        group = broadcast_files("/data/model/")
        ray.get(group.broadcast.remote("/data/dataset/"))
        ray.get(group.teardown.remote())
    """
    driver_ip = ray.util.get_node_ip_address()
    logger.info(
        f"broadcast_files: begin src_path={src_path} dst_dir={dst_dir} "
        f"chunk_size={chunk_size / 1e6:.0f}MB num_buffers={num_buffers} "
        f"driver_ip={driver_ip}"
    )
    started = time.monotonic()
    group = FileTransferGroup.remote()

    # Phase 1: spawn one worker per GPU node and rendezvous NCCL.
    # Passing the driver's IP makes this node rank 0 (the source).
    phase_t0 = time.monotonic()
    setup_ref = group.setup.remote(
        master_port=master_port,
        worker_options=worker_options,
        src_node_ip=driver_ip,
    )
    ray.get(setup_ref)
    setup_ms = (time.monotonic() - phase_t0) * 1000
    logger.info(f"broadcast_files: setup done in {setup_ms:.0f}ms")

    # Phase 2: run the actual broadcast across all workers.
    phase_t0 = time.monotonic()
    bcast_ref = group.broadcast.remote(
        src_path=src_path,
        dst_dir=dst_dir,
        chunk_size=chunk_size,
        num_buffers=num_buffers,
        bench_mode=bench_mode,
    )
    ray.get(bcast_ref)
    broadcast_ms = (time.monotonic() - phase_t0) * 1000
    total_ms = (time.monotonic() - started) * 1000
    logger.info(
        f"broadcast_files: done total={total_ms:.0f}ms "
        f"(setup={setup_ms:.0f}ms broadcast={broadcast_ms:.0f}ms)"
    )
    return group
def all_gather_files(
    src_dir: str,
    dst_dir: Optional[str] = None,
    chunk_size: int = 1024 * 1024 * 1024,
    num_buffers: int = 10,
    master_port: Optional[int] = None,
    worker_options: Optional[dict] = None,
    bench_mode: Optional[str] = None,
) -> FileTransferGroup:
    """
    All-gather files across all GPU nodes.

    Each node contributes files from src_dir; afterwards every node holds
    the union of all files in dst_dir (defaults to src_dir). The returned
    group can be reused for further operations.

    Args:
        src_dir: Directory containing this node's files.
        dst_dir: Destination directory. Defaults to src_dir.
        chunk_size: Bytes per NCCL chunk.
        num_buffers: Buffer pool size.
        master_port: torch.distributed rendezvous port.
        worker_options: Extra ray.remote options for workers.
        bench_mode: Optional benchmark mode forwarded to workers.

    Returns:
        The FileTransferGroup actor (reusable for more operations).
    """
    driver_ip = ray.util.get_node_ip_address()
    logger.info(
        f"all_gather_files: begin src_dir={src_dir} dst_dir={dst_dir} "
        f"chunk_size={chunk_size / 1e6:.0f}MB num_buffers={num_buffers} "
        f"driver_ip={driver_ip}"
    )
    started = time.monotonic()
    group = FileTransferGroup.remote()

    # Phase 1: spawn workers and initialize the process group. The driver
    # node is pinned to rank 0 for a deterministic rank layout.
    phase_t0 = time.monotonic()
    setup_ref = group.setup.remote(
        master_port=master_port,
        worker_options=worker_options,
        src_node_ip=driver_ip,
    )
    ray.get(setup_ref)
    setup_ms = (time.monotonic() - phase_t0) * 1000
    logger.info(f"all_gather_files: setup done in {setup_ms:.0f}ms")

    # Phase 2: run the all-gather across every worker.
    phase_t0 = time.monotonic()
    gather_ref = group.all_gather.remote(
        src_dir=src_dir,
        dst_dir=dst_dir,
        chunk_size=chunk_size,
        num_buffers=num_buffers,
        bench_mode=bench_mode,
    )
    ray.get(gather_ref)
    gather_ms = (time.monotonic() - phase_t0) * 1000
    total_ms = (time.monotonic() - started) * 1000
    logger.info(
        f"all_gather_files: done total={total_ms:.0f}ms "
        f"(setup={setup_ms:.0f}ms gather={gather_ms:.0f}ms)"
    )
    return group
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment