@zhuangh
Created December 2, 2025 16:18
3d_parallelism
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel
)
# ==========================================
# 1. FIXED Model Definition
# ==========================================
class TransformerBlock(nn.Module):
    def __init__(self, hidden_dim, ffn_dim):
        super().__init__()
        # FIX: Split "Attn" into Up/Down projections to satisfy the TP "sandwich" rule.
        # Flow: Replicated -> Colwise (split) -> Rowwise (replicated)
        self.attn_up = nn.Linear(hidden_dim, hidden_dim)
        self.attn_down = nn.Linear(hidden_dim, hidden_dim)
        # MLP part
        self.fc1 = nn.Linear(hidden_dim, ffn_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(ffn_dim, hidden_dim)

    def forward(self, x):
        # 1. Attention simulation (TP safe)
        # x is [.., 32] -> attn_up (Col, 16 per TP shard) -> attn_down (Row) -> [.., 32]
        h = self.attn_up(x)
        h = self.attn_down(h)
        x = x + h  # Residual 1 (safe: both sides are full-width 32)
        # 2. MLP block (TP safe)
        # x is [.., 32] -> fc1 (Col, 64 total, 32 per TP shard) -> fc2 (Row) -> [.., 32]
        h = self.fc2(self.relu(self.fc1(x)))
        return x + h  # Residual 2
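
# Why the Colwise -> Rowwise "sandwich" is TP-safe (sketch): ColwiseParallel shards
# a Linear's output features, so each TP rank produces only a slice of the
# intermediate activation; RowwiseParallel shards the next Linear's input features
# and reduces the partial results, so the block's input and output are full-width,
# replicated tensors. That keeps the residual additions above and the
# un-parallelized head below numerically correct.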
# ==========================================
# 2. The Training Function
# ==========================================
def train_step(rank, world_size):
    # --- Setup ---
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12377"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # --- 3D Mesh: 2 pipeline stages x 2 data-parallel replicas x 2 tensor-parallel shards ---
    mesh_3d = init_device_mesh("cpu", (2, 2, 2), mesh_dim_names=("pp", "dp", "tp"))
    tp_mesh = mesh_3d["tp"]
    dp_mesh = mesh_3d["dp"]
    pp_mesh = mesh_3d["pp"]
    pp_rank = mesh_3d.get_coordinate()[0]
    dp_rank = mesh_3d.get_coordinate()[1]  # used below to seed data per DP replica
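
    # Sanity check (sketch): init_device_mesh lays ranks out row-major, so for this
    # (2, 2, 2) mesh global_rank = pp * 4 + dp * 2 + tp. The send/recv peer math in
    # the training loop (rank +/- 4) relies on exactly this layout.
    assert pp_rank == rank // 4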
    # --- Model Init ---
    model_part = None
    if pp_rank == 0:
        model_part = nn.Sequential(
            TransformerBlock(32, 64),
            TransformerBlock(32, 64)
        )
    elif pp_rank == 1:
        model_part = nn.Sequential(
            TransformerBlock(32, 64),
            TransformerBlock(32, 64),
            nn.Linear(32, 10)  # Head (kept replicated, no TP)
        )

    # --- FIXED TP Plan ---
    tp_plan = {
        # Attention sandwich
        "attn_up": ColwiseParallel(),
        "attn_down": RowwiseParallel(),
        # MLP sandwich
        "fc1": ColwiseParallel(),
        "fc2": RowwiseParallel(),
        # NOTE: We do NOT parallelize the final "Head" (Linear(32, 10)).
        # This keeps it replicated so CrossEntropyLoss sees the full [B, 10] output.
    }
    for layer in model_part:
        # Only apply the plan to transformer blocks; the head stays replicated.
        if isinstance(layer, TransformerBlock):
            parallelize_module(layer, tp_mesh, tp_plan)

    # --- FSDP (shard the TP-parallelized stage over the "dp" mesh dim) ---
    fsdp_model = fully_shard(model_part, mesh=dp_mesh)
    optimizer = optim.AdamW(fsdp_model.parameters(), lr=1e-3)
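
    # Optional sanity check (sketch): after composing TP with fully_shard (FSDP2),
    # parameters are DTensors; the TP'd weights end up sharded over both the "dp"
    # and "tp" mesh dims. Printing placements on rank 0 only, to avoid log spam.
    if rank == 0:
        for name, param in fsdp_model.named_parameters():
            print(f"[Rank 0] {name}: {param.placements}")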
    # --- Training Loop ---
    for step in range(5):
        optimizer.zero_grad()
        # Settings
        batch_size = 4
        seq_len = 8
        hidden_dim = 32

        # --- STAGE 0 ---
        if pp_rank == 0:
            # Seed on the DP coordinate (not the global rank) so both TP ranks in
            # a TP group see identical inputs, as tensor parallelism requires.
            torch.manual_seed(1000 * dp_rank + step)
            inputs = torch.randn(batch_size, seq_len, hidden_dim)
            output_activations = fsdp_model(inputs)
            # Send to peer (simple mapping: 0->4, 1->5, etc.)
            peer_rank = rank + 4
            dist.send(output_activations.detach(), dst=peer_rank)
            # Receive grads
            grad_buffer = torch.zeros_like(output_activations)
            dist.recv(grad_buffer, src=peer_rank)
            output_activations.backward(grad_buffer)

        # --- STAGE 1 ---
        elif pp_rank == 1:
            recv_buffer = torch.zeros(batch_size, seq_len, hidden_dim)
            peer_rank = rank - 4
            dist.recv(recv_buffer, src=peer_rank)
            recv_buffer.requires_grad = True
            outputs = fsdp_model(recv_buffer)
            # Loss (targets are seeded on the DP coordinate so TP ranks agree)
            torch.manual_seed(1000 * dp_rank + step)
            outputs = outputs.view(-1, 10)
            targets = torch.randint(0, 10, (batch_size * seq_len,))
            loss = nn.CrossEntropyLoss()(outputs, targets)
            if step % 2 == 0:
                print(f"[Rank {rank}] Step {step} | Loss: {loss.item():.4f}")
            loss.backward()
            dist.send(recv_buffer.grad, dst=peer_rank)

        optimizer.step()

    # --- Shutdown ---
    dist.barrier()
    dist.destroy_process_group()
if __name__ == "__main__":
    world_size = 8
    mp.set_start_method("spawn", force=True)
    print(f"Running Fixed 3D Parallel Training on {world_size} Processes...")
    mp.spawn(train_step, args=(world_size,), nprocs=world_size, join=True)
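
# Usage (sketch, assuming this gist is saved as 3d_parallelism.py): run it directly
# with `python 3d_parallelism.py`. No torchrun launcher is needed, since mp.spawn
# starts all 8 CPU workers itself and each worker calls init_process_group with an
# explicit rank/world_size over the gloo backend.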