@zhuangh
Created December 2, 2025 16:35
5d_parallelism
import os
import math
import torch
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard
# ==========================================
# 0. UTILITY: Differentiable All-to-All
# ==========================================
class AllToAll(torch.autograd.Function):
    """
    A custom autograd function that allows gradients to flow through
    the all_to_all communication primitive.
    """
    @staticmethod
    def forward(ctx, input_tensor, group):
        ctx.group = group
        # input_tensor: [World, Capacity, Hidden]
        output_tensor = torch.empty_like(input_tensor)
        dist.all_to_all_single(output_tensor, input_tensor, group=group)
        return output_tensor
    @staticmethod
    def backward(ctx, grad_output):
        # The backward of all-to-all is simply another all-to-all
        # (symmetric communication).
        grad_input = torch.empty_like(grad_output)
        dist.all_to_all_single(grad_input, grad_output, group=ctx.group)
        return grad_input, None
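# --- Optional sanity check (sketch, not called by the training loop) ---
# Illustrates the symmetry claim above: because the backward of all_to_all is
# the mirrored all_to_all, a sum() loss should yield a gradient of 1.0 for
# every input element. Assumes a process group is already initialized (e.g.
# the gloo group created in train_5d); the helper name is illustrative.
def _check_all_to_all_grad(group=None):
    world = dist.get_world_size(group)
    x = torch.randn(world, 3, requires_grad=True)
    y = AllToAll.apply(x, group)
    y.sum().backward()
    assert torch.allclose(x.grad, torch.ones_like(x))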
# ==========================================
# 1. COMPONENT: Context Parallel (Ring Attention)
# ==========================================
class RingAttention(nn.Module):
    def __init__(self, hidden_dim, mesh):
        super().__init__()
        self.mesh = mesh
        self.pg = mesh.get_group()
        self.head_dim = hidden_dim // 4
        self.proj_q = nn.Linear(hidden_dim, hidden_dim)
        self.proj_k = nn.Linear(hidden_dim, hidden_dim)
        self.proj_v = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, hidden_dim)
    def forward(self, x):
        # x: [Batch, LocalSeq, Hidden] -- each CP rank holds one sequence shard
        q = self.proj_q(x)
        k = self.proj_k(x)
        v = self.proj_v(x)
        # Rank calculations: ring neighbors within the CP group
        group_rank = dist.get_rank(self.pg)
        group_size = dist.get_world_size(self.pg)
        next_rank_idx = (group_rank + 1) % group_size
        prev_rank_idx = (group_rank - 1 + group_size) % group_size
        next_global_rank = dist.get_global_rank(self.pg, next_rank_idx)
        prev_global_rank = dist.get_global_rank(self.pg, prev_rank_idx)
        curr_k, curr_v = k, v
        recv_k = torch.zeros_like(k)
        recv_v = torch.zeros_like(v)
        attn_out = torch.zeros_like(q)
        # Ring loop: at step s this rank holds the K/V shard originally owned
        # by rank (group_rank - s) mod group_size.
        for step in range(group_size):
            # 1. Compute. Note: softmax is omitted for simplicity, so this
            # accumulates unnormalized scores @ V rather than true softmax
            # attention (see the online-softmax sketch below the class).
            scores = torch.matmul(q, curr_k.transpose(1, 2)) / math.sqrt(self.head_dim)
            attn_out += torch.matmul(scores, curr_v)
            if step == group_size - 1:
                break
            # 2. Communicate: pass K/V to the next rank, receive from the previous
            reqs = [
                dist.isend(curr_k, dst=next_global_rank),
                dist.isend(curr_v, dst=next_global_rank),
                dist.irecv(recv_k, src=prev_global_rank),
                dist.irecv(recv_v, src=prev_global_rank)
            ]
            for req in reqs:
                req.wait()
            curr_k = recv_k.clone()
            curr_v = recv_v.clone()
        return self.out(attn_out)
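# --- Numerically stable accumulation (sketch, not used above) ---
# The ring loop above sums unnormalized scores @ V. A faithful ring attention
# carries running max / sum statistics across ring steps (flash-attention
# style) and divides attn_out by running_sum at the end. The helper below
# shows one such step; the names and final division are assumptions, not part
# of the original gist.
def _online_softmax_step(attn_out, running_max, running_sum, scores, values):
    block_max = scores.max(dim=-1, keepdim=True).values
    new_max = torch.maximum(running_max, block_max)
    correction = torch.exp(running_max - new_max)
    probs = torch.exp(scores - new_max)
    running_sum = running_sum * correction + probs.sum(dim=-1, keepdim=True)
    attn_out = attn_out * correction + torch.matmul(probs, values)
    return attn_out, new_max, running_sum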
# ==========================================
# 2. COMPONENT: Expert Parallel (MoE MLP)
# ==========================================
class MoELayer(nn.Module):
    def __init__(self, hidden_dim, ffn_dim, mesh):
        super().__init__()
        self.mesh = mesh
        self.pg = mesh.get_group()
        self.num_experts = dist.get_world_size(self.pg)
        self.router = nn.Linear(hidden_dim, self.num_experts)
        self.expert = nn.Sequential(
            nn.Linear(hidden_dim, ffn_dim),
            nn.ReLU(),
            nn.Linear(ffn_dim, hidden_dim)
        )
    def forward(self, x):
        b, s, h = x.shape
        x_flat = x.view(-1, h)
        # 1. Route (top-1: each token goes to its highest-scoring expert)
        logits = self.router(x_flat)
        _, indices = torch.max(logits, dim=1)
        capacity = x_flat.size(0) // self.num_experts
        dispatched_x = torch.zeros(self.num_experts, capacity, h)
        counts = torch.zeros(self.num_experts, dtype=torch.long)
        # Track which slot each token lands in so the combined output can be
        # scattered back to the original token order.
        slots = torch.full((x_flat.size(0),), -1, dtype=torch.long)
        # 2. Assign to slots (naive CPU implementation)
        # Note: the 'indices' (torch.max) operation is non-differentiable.
        # This breaks gradient flow to the router, but NOT to the expert.
        # To train the router, the dispatched tokens would need to be scaled
        # by their softmax probabilities (see the _top1_gate sketch below).
        for i in range(x_flat.size(0)):
            dest = indices[i]
            if counts[dest] < capacity:
                dispatched_x[dest, counts[dest]] = x_flat[i]
                slots[i] = counts[dest]
                counts[dest] += 1
        # 3. Dispatch (differentiable via the custom autograd function)
        recv_buffer = AllToAll.apply(dispatched_x, self.pg)
        # 4. Expert compute
        expert_out = self.expert(recv_buffer)
        # 5. Combine (differentiable): the second all-to-all returns each
        # token to the rank that dispatched it, in the same slot order.
        combined = AllToAll.apply(expert_out, self.pg)
        # 6. Un-permute back to the original token order; tokens dropped for
        # exceeding capacity keep a zero output.
        out_flat = torch.zeros_like(x_flat)
        for i in range(x_flat.size(0)):
            if slots[i] >= 0:
                out_flat[i] = combined[indices[i], slots[i]]
        return out_flat.view(b, s, h)
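# --- Router-gradient fix (sketch, not used above) ---
# torch.max breaks gradient flow to the router. A common top-1 remedy is to
# scale each dispatched token (or the combined output) by its softmax routing
# probability, so the router is trained through that product. The helper below
# is an assumed illustration, not part of the original gist.
def _top1_gate(logits):
    probs = torch.softmax(logits, dim=1)
    gate, indices = probs.max(dim=1)  # gate carries gradients to the router
    return indices, gate              # e.g. dispatch gate[i] * x_flat[i]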
# ==========================================
# 3. The 5D Training Loop
# ==========================================
def train_5d(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12391"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # --- Mesh Setup ---
    mesh = init_device_mesh("cpu", (2, 2, 2), mesh_dim_names=("pp", "ep", "cp"))
    pp_mesh = mesh["pp"]
    ep_mesh = mesh["ep"]
    cp_mesh = mesh["cp"]
    pp_rank = mesh.get_coordinate()[0]
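    # Rank layout for the (pp=2, ep=2, cp=2) mesh (assumption: DeviceMesh
    # assigns ranks row-major, with "pp" as the outermost dimension):
    #   global rank = 4*pp + 2*ep + cp
    # so ranks 0-3 form pipeline stage 0, ranks 4-7 form stage 1, and each
    # stage-0 rank pairs with the pipeline peer at rank + 4 below.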
    if pp_rank == 0:
        # Pipeline stage 0: attention + MoE, no classifier head
        model = nn.Sequential(
            RingAttention(hidden_dim=32, mesh=cp_mesh),
            MoELayer(hidden_dim=32, ffn_dim=64, mesh=ep_mesh)
        )
    else:
        # Pipeline stage 1: attention + MoE + classifier head
        model = nn.Sequential(
            RingAttention(hidden_dim=32, mesh=cp_mesh),
            MoELayer(hidden_dim=32, ffn_dim=64, mesh=ep_mesh),
            nn.Linear(32, 10)
        )
    # Shard parameters (FSDP2) across the "ep" mesh dimension
    fsdp_model = fully_shard(model, mesh=ep_mesh)
    optimizer = optim.AdamW(fsdp_model.parameters(), lr=1e-3)
    # --- Training ---
    local_seq_len = 8
    hidden_dim = 32
    batch_size = 4
    for step in range(3):
        optimizer.zero_grad()
        if pp_rank == 0:
            torch.manual_seed(rank + step)
            inputs = torch.randn(batch_size, local_seq_len, hidden_dim)
            out = fsdp_model(inputs)
            # Send activations to the matching stage-1 rank, then wait for
            # the activation gradient to come back.
            peer = rank + 4
            dist.send(out.detach(), dst=peer)
            grad = torch.zeros_like(out)
            dist.recv(grad, src=peer)
            # out.backward() works because MoELayer is differentiable
            out.backward(grad)
        elif pp_rank == 1:
            peer = rank - 4
            inputs = torch.zeros(batch_size, local_seq_len, hidden_dim)
            dist.recv(inputs, src=peer)
            inputs.requires_grad = True
            out = fsdp_model(inputs)
            targets = torch.randint(0, 10, (batch_size * local_seq_len,))
            loss = nn.CrossEntropyLoss()(out.view(-1, 10), targets)
            if step % 2 == 0:
                print(f"Rank {rank} (Stage 1): Loss {loss.item():.4f}")
            loss.backward()
            # Check for None gradients to catch a broken graph early
            if inputs.grad is None:
                print(f"Rank {rank}: FATAL - inputs.grad is None!")
            else:
                # Send the activation gradient back to stage 0
                dist.send(inputs.grad, dst=peer)
        optimizer.step()
    print(f"Rank {rank}: Finished. Barrier...")
    dist.barrier()
    dist.destroy_process_group()
if __name__ == "__main__":
    world_size = 8
    mp.set_start_method("spawn", force=True)
    print(f"Running Final 5D Parallelism on {world_size} Processes...")
    mp.spawn(train_5d, args=(world_size,), nprocs=world_size, join=True)
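# Launch note (sketch, assuming a single CPU-only host): running this file
# directly spawns 8 local processes over the gloo backend via mp.spawn; the
# process count must equal the mesh volume pp * ep * cp = 2 * 2 * 2 = 8.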