Last active
March 2, 2026 01:49
#@title Megatron Sharding
"""
╔══════════════════════════════════════════════════════════════════════════╗
║              MEGATRON-LM: TENSOR PARALLELISM FOR FFN BLOCK               ║
║                 (Column + Row Parallel Linear Layers)                    ║
╚══════════════════════════════════════════════════════════════════════════╝

The FFN block computes:  Y = ReLU(X @ W1ᵀ) @ W2ᵀ
(simplified: no ReLU here, for clean gradient verification)

UNSHARDED (single GPU):

  ┌───────────┐    @ W1ᵀ (2×3)    ┌───────────┐    @ W2ᵀ (3×2)    ┌───────────┐
  │  X (3×3)  │ ────────────────▶ │  O (3×2)  │ ────────────────▶ │  Y (3×3)  │
  └───────────┘                   └───────────┘                   └───────────┘

SHARDED ACROSS 2 GPUs (Megatron style):

  ┌──────────────────────────────────────────────────────────┐
  │ KEY IDEA: shard W1 by ROWS, shard W2 by COLUMNS          │
  │   → each GPU handles half the "hidden" dimension         │
  │   → final outputs are summed via all-reduce              │
  └──────────────────────────────────────────────────────────┘

  (Naming note: weights are stored (out, in) and applied as X @ Wᵀ,
   so sharding W1's rows splits the *mathematical* weight's columns —
   this is Megatron's ColumnParallelLinear. Likewise, sharding W2's
   columns is Megatron's RowParallelLinear.)

  GPU 0: X @ W1_shard0ᵀ → O_shard0 → O_shard0 @ W2_shard0ᵀ → Y_partial0
  GPU 1: X @ W1_shard1ᵀ → O_shard1 → O_shard1 @ W2_shard1ᵀ → Y_partial1
                                                   ↓
                                      Y = Y_partial0 + Y_partial1
                                        (all-reduce / sum)
"""
import torch

torch.manual_seed(42)
# ═══════════════════════════════════════════════════════════════
# SETUP: Input and Weight Matrices
# ═══════════════════════════════════════════════════════════════
#
# X  : (3×3) — batch of 3 tokens, each with 3 features
# W1 : (2×3) — projects 3 → 2 (hidden dim = 2)
# W2 : (3×2) — projects 2 → 3 (output dim = 3)
#
#   ┌─────────────────┐
#   │    X (3 × 3)    │  ← 3 tokens × 3 input features
#   └─────────────────┘
#           │
#           │ @ W1ᵀ (3×2)        W1 (2×3):
#           │                    ┌──────────────┐
#           ▼                    │ row 0 (1×3)  │ ← W1_shard0 → GPU 0
#   ┌─────────────────┐          │ row 1 (1×3)  │ ← W1_shard1 → GPU 1
#   │    O (3 × 2)    │          └──────────────┘
#   └─────────────────┘            ↑ Row-parallel sharding
#           │
#           │ @ W2ᵀ (2×3)        W2 (3×2):
#           │                    ┌──────────────┐
#           ▼                    │ col 0 (3×1)  │ ← W2_shard0 → GPU 0
#   ┌─────────────────┐          │ col 1 (3×1)  │ ← W2_shard1 → GPU 1
#   │    Y (3 × 3)    │          └──────────────┘
#   └─────────────────┘            ↑ Column-parallel sharding
X = torch.rand(3, 3)
W1 = torch.rand(2, 3)
W2 = torch.rand(3, 2)
X.requires_grad = True
W1.requires_grad = True
W2.requires_grad = True
# ═══════════════════════════════════════════════════════════════
# UNSHARDED FORWARD PASS (ground truth on a single GPU)
# ═══════════════════════════════════════════════════════════════
O = X @ W1.t()  # (3, 2)
Y = O @ W2.t()  # (3, 3)
# (No retain_grad() needed: W1 and W2 are leaf tensors, so backward()
# populates their .grad automatically.)
# ═══════════════════════════════════════════════════════════════
# SHARDED FORWARD PASS
# ═══════════════════════════════════════════════════════════════
#
# STEP 1: Shard W1 along ROWS (dim=0) — Megatron's column-parallel layer
# ──────────────────────────────────────────────────────────────
#
#   W1 (2×3):        W1_shard0 (1×3):       W1_shard1 (1×3):
#   ┌───────────┐    ┌───────────┐          ┌───────────┐
#   │   row 0   │ →  │   row 0   │ GPU 0    │   row 1   │ GPU 1
#   │   row 1   │    └───────────┘          └───────────┘
#   └───────────┘
#
# Each GPU independently computes its slice of the hidden dim:
#   GPU 0: O_shard0 = X @ W1_shard0ᵀ → (3×1)
#   GPU 1: O_shard1 = X @ W1_shard1ᵀ → (3×1)
#
# Together they reconstruct O = [O_shard0 | O_shard1] → (3×2)
# (No communication needed here! Each GPU uses the full X.)
W1_shard0, W1_shard1 = W1.chunk(chunks=2, dim=0)  # each (1, 3); ColumnParallel in Megatron naming
O_shard0 = X @ W1_shard0.t()  # (3, 1) = (3, 3) @ (3, 1)
O_shard1 = X @ W1_shard1.t()  # (3, 1) = (3, 3) @ (3, 1)
W2_shard0, W2_shard1 = W2.chunk(chunks=2, dim=1)  # each (3, 1); RowParallel in Megatron naming
Y_partial0 = O_shard0 @ W2_shard0.t()  # (3, 3) = (3, 1) @ (1, 3)
Y_partial1 = O_shard1 @ W2_shard1.t()  # (3, 3) = (3, 1) @ (1, 3)
Y_all_reduced = Y_partial0 + Y_partial1
# Verify that the all-gathered O shards match the unsharded O.
O_sharded = torch.concat([O_shard0, O_shard1], dim=1)
torch.testing.assert_close(O_sharded, O)
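# ───────────────────────────────────────────────────────────────
# ASIDE (added, not part of the original derivation): an elementwise
# nonlinearity (ReLU/GeLU) commutes with sharding along the hidden dim,
# so each GPU could apply ReLU to its own O_shard with zero
# communication — this is why Megatron can skip the real FFN's
# activation without changing the communication pattern. A minimal
# self-contained check on fresh tensors:
demo_O = torch.rand(3, 2)
demo_s0, demo_s1 = demo_O.chunk(2, dim=1)  # per-"GPU" slices of the hidden dim
relu_gathered = torch.cat([torch.relu(demo_s0), torch.relu(demo_s1)], dim=1)
torch.testing.assert_close(relu_gathered, torch.relu(demo_O))  # elementwise op is shard-local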
# ─────────────────────────────────────────────────────────────
# STEP 2: Shard W2 along COLUMNS (dim=1) — Megatron's row-parallel layer
# ─────────────────────────────────────────────────────────────
#
#   W2 (3×2):       W2_shard0 (3×1):     W2_shard1 (3×1):
#   ┌──────────┐    ┌─────┐              ┌─────┐
#   │ c0 | c1  │ →  │ c0  │ GPU 0        │ c1  │ GPU 1
#   │ c0 | c1  │    │ c0  │              │ c1  │
#   │ c0 | c1  │    │ c0  │              │ c1  │
#   └──────────┘    └─────┘              └─────┘
#
# Each GPU uses its local O_shard and W2_shard to compute a partial Y:
#   GPU 0: Y_partial0 = O_shard0 @ W2_shard0ᵀ → (3×3)
#   GPU 1: Y_partial1 = O_shard1 @ W2_shard1ᵀ → (3×3)
#
#   ┌─────────────┐     ┌─────────────┐     ┌─────────────┐
#   │ Y_partial0  │  +  │ Y_partial1  │  =  │      Y      │
#   │    (3×3)    │     │    (3×3)    │     │    (3×3)    │
#   └─────────────┘     └─────────────┘     └─────────────┘
#        ↑ all-reduce (sum) across GPUs ↑
# Verify that the all-reduced Y matches the unsharded Y.
torch.testing.assert_close(Y_all_reduced, Y)
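# ───────────────────────────────────────────────────────────────
# ASIDE (added): why the sum is exact. A matmul decomposes over its
# inner (hidden) dimension as a sum of outer products:
#   M @ N = Σ_g M[:, g] ⊗ N[g, :]
# Each "GPU" owns one slice g of the hidden dim, so summing the
# per-GPU partials (the all-reduce) reconstructs the full product.
# Quick check with fresh tensors:
demo_M = torch.rand(3, 2)
demo_N = torch.rand(2, 3)
outer_sum = sum(torch.outer(demo_M[:, g], demo_N[g, :]) for g in range(2))
torch.testing.assert_close(outer_sum, demo_M @ demo_N)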
# ═══════════════════════════════════════════════════════════════
# UNSHARDED BACKWARD PASS (ground truth)
# ═══════════════════════════════════════════════════════════════
#
# Loss = mean(Y) → dL/dY = ones(3,3) / 9
#
#   dL/dO  = dL/dY @ W2     (chain rule through W2)
#   dL/dX  = dL/dO @ W1     (chain rule through W1)
#   dL/dW1 = dL/dOᵀ @ X
#   dL/dW2 = dL/dYᵀ @ O
loss = Y.mean()
loss.backward()
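# ───────────────────────────────────────────────────────────────
# ASIDE (added): sanity-checking the closed-form gradient formulas
# above against autograd, on fresh tensors so nothing above is
# disturbed:
chk_X = torch.rand(3, 3, requires_grad=True)
chk_W1 = torch.rand(2, 3, requires_grad=True)
chk_W2 = torch.rand(3, 2, requires_grad=True)
chk_O = chk_X @ chk_W1.t()
chk_Y = chk_O @ chk_W2.t()
chk_Y.mean().backward()
dY = torch.ones(3, 3) / 9          # dL/dY for L = mean(Y)
dO = dY @ chk_W2.detach()          # dL/dO
torch.testing.assert_close(chk_X.grad, dO @ chk_W1.detach())      # dL/dX
torch.testing.assert_close(chk_W1.grad, dO.t() @ chk_X.detach())  # dL/dW1
torch.testing.assert_close(chk_W2.grad, dY.t() @ chk_O.detach())  # dL/dW2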
# ═══════════════════════════════════════════════════════════════
# SHARDED BACKWARD PASS
# ═══════════════════════════════════════════════════════════════
#
# Forward recap (for gradient derivation):
#   ┌──────────────────────────────────────────────────────┐
#   │ GPU g: Y_partial_g = O_shard_g @ W2_shard_gᵀ         │
#   │        Y = Σ_g Y_partial_g   (all-reduce)            │
#   │        loss = mean(Y)                                │
#   └──────────────────────────────────────────────────────┘
#
# Since Y = Y_partial0 + Y_partial1 and loss = mean(Y):
#   dL/dY_partial_g = dL/dY = ones(3,3)/9   ← same on both GPUs
#                                             (broadcast, not reduce)
dL_by_dY_all_reduced = torch.ones(3, 3) / 9
dL_by_dY_partial0 = dL_by_dY_all_reduced  # (3, 3)
dL_by_dY_partial1 = dL_by_dY_all_reduced  # (3, 3)
dL_by_dO_shard0 = dL_by_dY_partial0 @ W2_shard0  # (3, 1) = (3, 3) @ (3, 1)
dL_by_dO_shard1 = dL_by_dY_partial1 @ W2_shard1  # (3, 1) = (3, 3) @ (3, 1)
# ─────────────────────────────────────────────────────────────
# Gradient w.r.t. X (needs an all-reduce!)
# ─────────────────────────────────────────────────────────────
#
# dL/dX_partial_g = dL/dO_shard_g @ W1_shard_g
#
#   GPU 0: (3×1) @ (1×3) → (3×3)
#   GPU 1: (3×1) @ (1×3) → (3×3)
#
#   ┌──────────────────┐     ┌──────────────────┐
#   │  dL/dX_partial0  │  +  │  dL/dX_partial1  │
#   │      (3×3)       │     │      (3×3)       │
#   └──────────────────┘     └──────────────────┘
#        ↑ all-reduce (sum) → dL/dX (3×3)
#
# WHY? Because X contributed to O_shard0 AND O_shard1 via
# different weight shards, so its total gradient is the sum
# of contributions from all GPUs.
dL_by_dX_partial0 = dL_by_dO_shard0 @ W1_shard0  # (3, 1) @ (1, 3)
dL_by_dX_partial1 = dL_by_dO_shard1 @ W1_shard1  # (3, 1) @ (1, 3)
# all-reduce to get the full gradient
dL_by_dX = dL_by_dX_partial0 + dL_by_dX_partial1
torch.testing.assert_close(dL_by_dX, X.grad)
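# ───────────────────────────────────────────────────────────────
# ASIDE (added): in a single-process simulation like this one, autograd
# performs the backward "all-reduce" for us — both shard branches read
# the same leaf X, and autograd accumulates gradients from every
# consumer of a tensor, which is exactly the Σ_g above. Fresh-tensor
# demo:
acc_X = torch.rand(3, 3, requires_grad=True)
acc_W1 = torch.rand(2, 3)
acc_W2 = torch.rand(3, 2)
a0, a1 = acc_W1.chunk(2, dim=0)
b0, b1 = acc_W2.chunk(2, dim=1)
# sharded forward: two branches through the same leaf acc_X
(acc_X @ a0.t() @ b0.t() + acc_X @ a1.t() @ b1.t()).mean().backward()
# reference: unsharded forward on a detached copy of the same data
ref_X = acc_X.detach().clone().requires_grad_(True)
((ref_X @ acc_W1.t()) @ acc_W2.t()).mean().backward()
torch.testing.assert_close(acc_X.grad, ref_X.grad)  # accumulation == all-reduce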
# ─────────────────────────────────────────────────────────────
# Gradient w.r.t. weight shards (local — no communication!)
# ─────────────────────────────────────────────────────────────
#
# Each GPU only needs its local activations to update its weight shard.
# This is the beauty of Megatron sharding: weight gradients are local,
# so each GPU's optimizer updates only its own shard.
#
# dL/dW2_shard_g = dL/dY_partial_gᵀ @ O_shard_g
#   GPU 0: (3×3)ᵀ @ (3×1) → (3×1)  ← gradient for W2_shard0 on GPU 0
#   GPU 1: (3×3)ᵀ @ (3×1) → (3×1)  ← gradient for W2_shard1 on GPU 1
#
# dL/dW1_shard_g = dL/dO_shard_gᵀ @ X
#   GPU 0: (3×1)ᵀ @ (3×3) → (1×3)  ← gradient for W1_shard0 on GPU 0
#   GPU 1: (3×1)ᵀ @ (3×3) → (1×3)  ← gradient for W1_shard1 on GPU 1
dL_by_dW2_shard0 = dL_by_dY_partial0.t() @ O_shard0  # (3, 1) = (3, 3) @ (3, 1)
dL_by_dW2_shard1 = dL_by_dY_partial1.t() @ O_shard1  # (3, 1) = (3, 3) @ (3, 1)
dL_by_dW1_shard0 = dL_by_dO_shard0.t() @ X  # (1, 3) = (1, 3) @ (3, 3)
dL_by_dW1_shard1 = dL_by_dO_shard1.t() @ X  # (1, 3) = (1, 3) @ (3, 3)
torch.testing.assert_close(W1.grad, torch.vstack([dL_by_dW1_shard0, dL_by_dW1_shard1]))  # W1 was row-sharded (dim=0)
torch.testing.assert_close(W2.grad, torch.hstack([dL_by_dW2_shard0, dL_by_dW2_shard1]))  # W2 was column-sharded (dim=1)
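# ───────────────────────────────────────────────────────────────
# ASIDE (added): the 2-GPU construction generalizes to any world size
# that divides the hidden dim. A forward-only check (the shapes and
# helper name here are arbitrary, picked just for this demo):
def megatron_ffn_check(n_tokens=4, d_in=6, d_hidden=8, d_out=6, world_size=4):
    Xd = torch.rand(n_tokens, d_in)
    W1d = torch.rand(d_hidden, d_in)   # sharded by rows across "GPUs"
    W2d = torch.rand(d_out, d_hidden)  # sharded by columns across "GPUs"
    Y_ref = (Xd @ W1d.t()) @ W2d.t()   # unsharded ground truth
    Y_sum = torch.zeros(n_tokens, d_out)
    for w1_g, w2_g in zip(W1d.chunk(world_size, dim=0), W2d.chunk(world_size, dim=1)):
        Y_sum = Y_sum + (Xd @ w1_g.t()) @ w2_g.t()  # one "GPU"'s partial
    torch.testing.assert_close(Y_sum, Y_ref)

megatron_ffn_check()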