3d_parallelism
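A single-file example of 3D parallelism in PyTorch: 2-way pipeline parallelism (manual send/recv between two stages), 2-way data parallelism (FSDP2 fully_shard), and 2-way tensor parallelism (ColwiseParallel/RowwiseParallel), run as 8 CPU processes over the gloo backend.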
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
)
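# Note: fully_shard is the FSDP2 API; importing it from torch.distributed.fsdp
# (rather than the older torch.distributed._composable.fsdp path) assumes a
# recent PyTorch release.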
# ==========================================
# 1. FIXED Model Definition
# ==========================================
class TransformerBlock(nn.Module):
    def __init__(self, hidden_dim, ffn_dim):
        super().__init__()
        # FIX: Split 'Attn' into Up/Down projections to satisfy the TP "sandwich" rule.
        # Flow: Replicated -> Colwise (sharded) -> Rowwise (replicated again)
        self.attn_up = nn.Linear(hidden_dim, hidden_dim)
        self.attn_down = nn.Linear(hidden_dim, hidden_dim)
        # MLP part
        self.fc1 = nn.Linear(hidden_dim, ffn_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(ffn_dim, hidden_dim)

    def forward(self, x):
        # 1. Attention simulation (TP safe)
        # x: [B, S, 32] -> attn_up (Colwise, local [B, S, 16] per TP rank)
        #               -> attn_down (Rowwise, all-reduced) -> [B, S, 32]
        h = self.attn_up(x)
        h = self.attn_down(h)
        x = x + h  # Residual 1 (safe: both operands are replicated [B, S, 32])
        # 2. MLP block (TP safe)
        # x: [B, S, 32] -> fc1 (Colwise, local [B, S, 32] per TP rank)
        #               -> fc2 (Rowwise, all-reduced) -> [B, S, 32]
        h = self.fc2(self.relu(self.fc1(x)))
        return x + h  # Residual 2
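# With the 2-way TP plan applied below (and hidden_dim=32, ffn_dim=64 as
# constructed in train_step), each block's Linear weights are sharded as:
#   attn_up (Colwise):   weight [32, 32] -> local shard [16, 32] per TP rank
#   attn_down (Rowwise): weight [32, 32] -> local shard [32, 16] per TP rank
#   fc1 (Colwise):       weight [64, 32] -> local shard [32, 32] per TP rank
#   fc2 (Rowwise):       weight [32, 64] -> local shard [32, 32] per TP rank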

# ==========================================
# 2. The Training Function
# ==========================================
def train_step(rank, world_size):
    # --- Setup ---
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12377"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # --- 3D Mesh ---
    mesh_3d = init_device_mesh("cpu", (2, 2, 2), mesh_dim_names=("pp", "dp", "tp"))
    tp_mesh = mesh_3d["tp"]
    dp_mesh = mesh_3d["dp"]
    pp_mesh = mesh_3d["pp"]
    pp_rank = mesh_3d.get_coordinate()[0]
    dp_rank = mesh_3d.get_coordinate()[1]  # used below to seed per-DP-group data
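    # Rank layout: init_device_mesh fills the (pp, dp, tp) = (2, 2, 2) mesh in
    # row-major order, so global rank r maps to coordinates
    # (pp, dp, tp) = (r // 4, (r % 4) // 2, r % 2).
    # Ranks 0-3 form pipeline stage 0 and ranks 4-7 form stage 1, which is why
    # the peer rank for the point-to-point sends below is simply rank +/- 4.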
    # --- Model Init ---
    model_part = None
    if pp_rank == 0:
        model_part = nn.Sequential(
            TransformerBlock(32, 64),
            TransformerBlock(32, 64),
        )
    elif pp_rank == 1:
        model_part = nn.Sequential(
            TransformerBlock(32, 64),
            TransformerBlock(32, 64),
            nn.Linear(32, 10),  # Head (kept replicated, no TP)
        )
    # --- FIXED TP Plan ---
    tp_plan = {
        # Attention sandwich
        "attn_up": ColwiseParallel(),
        "attn_down": RowwiseParallel(),
        # MLP sandwich
        "fc1": ColwiseParallel(),
        "fc2": RowwiseParallel(),
        # NOTE: We do NOT parallelize the final "Head" (Linear(32, 10)).
        # This keeps it replicated so CrossEntropyLoss sees the full [B, 10] output.
    }
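    # ColwiseParallel shards a Linear's weight along its output features, so its
    # activations come out sharded on the last dim; RowwiseParallel shards along
    # the input features and all-reduces the partial results, so the pair's
    # output is replicated again. Applying the styles in Col -> Row pairs keeps
    # every block's input and output replicated across the tp mesh.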
    for layer in model_part:
        if isinstance(layer, TransformerBlock):  # skip the replicated Linear head
            parallelize_module(layer, tp_mesh, tp_plan)
    # --- FSDP ---
    fsdp_model = fully_shard(model_part, mesh=dp_mesh)
    optimizer = optim.AdamW(fsdp_model.parameters(), lr=1e-3)
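    # Ordering note: TP (parallelize_module) is applied before FSDP (fully_shard),
    # so each weight ends up sharded along both the tp and dp mesh dimensions;
    # the optimizer then operates on these DTensor parameters directly.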
    # --- Training Loop ---
    for step in range(5):
        optimizer.zero_grad()
        # Settings
        batch_size = 4
        seq_len = 8
        hidden_dim = 32
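        # Manual two-stage pipeline, one micro-batch per step (no 1F1B/GPipe
        # scheduling): stage 0 runs forward, sends its activations, then blocks
        # until stage 1 returns the gradient of those activations; stage 1
        # receives the activations, computes the loss, runs backward, and sends
        # the input gradient back.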
        # --- STAGE 0 ---
        if pp_rank == 0:
            # Seed by (dp_rank, step) rather than global rank: TP peers must see
            # identical (replicated) inputs, while DP peers get different batches.
            torch.manual_seed(dp_rank * 1000 + step)
            inputs = torch.randn(batch_size, seq_len, hidden_dim)
            output_activations = fsdp_model(inputs)
            # Send to the peer on stage 1 (simple mapping: 0 -> 4, 1 -> 5, etc.)
            peer_rank = rank + 4
            dist.send(output_activations.detach(), dst=peer_rank)
            # Receive the gradient of the activations and run backward locally
            grad_buffer = torch.zeros_like(output_activations)
            dist.recv(grad_buffer, src=peer_rank)
            output_activations.backward(grad_buffer)
        # --- STAGE 1 ---
        elif pp_rank == 1:
            recv_buffer = torch.zeros(batch_size, seq_len, hidden_dim)
            peer_rank = rank - 4
            dist.recv(recv_buffer, src=peer_rank)
            recv_buffer.requires_grad = True
            outputs = fsdp_model(recv_buffer)
            # Loss: seed the same way as stage 0 so TP peers draw identical targets.
            torch.manual_seed(dp_rank * 1000 + step)
            outputs = outputs.view(-1, 10)
            targets = torch.randint(0, 10, (batch_size * seq_len,))
            loss = nn.CrossEntropyLoss()(outputs, targets)
            if step % 2 == 0:
                print(f"[Rank {rank}] Step {step} | Loss: {loss.item():.4f}")
            loss.backward()
            dist.send(recv_buffer.grad, dst=peer_rank)
        optimizer.step()

    # --- Shutdown ---
    dist.barrier()
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 8
    mp.set_start_method("spawn", force=True)
    print(f"Running Fixed 3D Parallel Training on {world_size} Processes...")
    mp.spawn(train_step, args=(world_size,), nprocs=world_size, join=True)
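Running the file directly launches all 8 CPU processes via mp.spawn (no torchrun invocation is needed). Only the stage-1 ranks compute a loss, and they print it on even-numbered steps only.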