Skip to content

Instantly share code, notes, and snippets.

@sourabh2k15
Last active March 2, 2026 01:49
Show Gist options
  • Select an option

  • Save sourabh2k15/f03db85e78bf82fca66bd1a84f97681d to your computer and use it in GitHub Desktop.

Select an option

Save sourabh2k15/f03db85e78bf82fca66bd1a84f97681d to your computer and use it in GitHub Desktop.
#@title Megatron Sharding
"""
Megatron-LM style tensor parallelism for an FFN block
(column- + row-parallel linear layers), verified against a
single-device ("unsharded") reference in both forward and backward.

The FFN block computes: Y = ReLU(X @ W1ᵀ) @ W2ᵀ
(ReLU omitted here for clean gradient verification)

UNSHARDED (single GPU):
    X (3×3) ──@ W1ᵀ──▶ O (3×2) ──@ W2ᵀ──▶ Y (3×3)

SHARDED ACROSS 2 GPUs (Megatron style):
  KEY IDEA: shard W1 by ROWS, shard W2 by COLUMNS
  → each GPU handles half the "hidden" dimension
  → final outputs are summed via All-Reduce

    GPU g: X @ W1_shard_gᵀ → O_shard_g → O_shard_g @ W2_shard_gᵀ → Y_partial_g
    Y = Σ_g Y_partial_g        (All-Reduce / sum)
"""
import torch

torch.manual_seed(42)

# ═══════════════════════════════════════════════════════════════
# SETUP: Input and Weight Matrices
#   X  : (3×3) — batch of 3 tokens, each with 3 features
#   W1 : (2×3) — projects 3 → 2 (hidden dim = 2); row-parallel shard
#   W2 : (3×2) — projects 2 → 3 (output dim = 3); column-parallel shard
# ═══════════════════════════════════════════════════════════════
X = torch.rand(3, 3)
W1 = torch.rand(2, 3)
W2 = torch.rand(3, 2)
X.requires_grad = True
W1.requires_grad = True
W2.requires_grad = True

# ═══════════════════════════════════════════════════════════════
# UNSHARDED FORWARD PASS (ground truth on a single GPU)
# ═══════════════════════════════════════════════════════════════
O = X @ W1.t()  # (3, 2)
Y = O @ W2.t()  # (3, 3)
# NOTE: X/W1/W2 are leaf tensors with requires_grad=True, so backward()
# populates their .grad automatically — no retain_grad() calls needed.

# ═══════════════════════════════════════════════════════════════
# SHARDED FORWARD PASS
#
# STEP 1: Shard W1 along ROWS (row-parallel, dim=0).
# Each GPU independently computes its slice of the hidden dim
# from the FULL input X — no communication needed here:
#   GPU g: O_shard_g = X @ W1_shard_gᵀ → (3×1)
# Together the shards reconstruct O = [O_shard0 | O_shard1] → (3×2).
# ═══════════════════════════════════════════════════════════════
W1_shard0, W1_shard1 = W1.chunk(chunks=2, dim=0)  # each (1, 3), RowParallel
O_shard0 = X @ W1_shard0.t()  # (3, 1) = (3, 3) @ (3, 1)
O_shard1 = X @ W1_shard1.t()  # (3, 1) = (3, 3) @ (3, 1)

# STEP 2: Shard W2 along COLUMNS (column-parallel, dim=1).
# Each GPU uses its local O_shard and W2_shard to compute a
# full-shaped partial Y, then the partials are summed (All-Reduce):
#   GPU g: Y_partial_g = O_shard_g @ W2_shard_gᵀ → (3×3)
W2_shard0, W2_shard1 = W2.chunk(chunks=2, dim=1)  # each (3, 1), ColParallel
Y_partial0 = O_shard0 @ W2_shard0.t()  # (3, 3) = (3, 1) @ (1, 3)
Y_partial1 = O_shard1 @ W2_shard1.t()  # (3, 3) = (3, 1) @ (1, 3)
Y_all_reduced = Y_partial0 + Y_partial1  # All-Reduce (sum) across GPUs

# Verify intermediate O shards, all-gathered, match the unsharded O.
O_sharded = torch.cat([O_shard0, O_shard1], dim=1)
torch.testing.assert_close(O_sharded, O)
# Verify the final all-reduced Y matches the unsharded Y.
torch.testing.assert_close(Y_all_reduced, Y)

# ═══════════════════════════════════════════════════════════════
# UNSHARDED BACKWARD PASS (ground truth)
#
#   loss = mean(Y) → dL/dY = ones_like(Y) / Y.numel()
#   dL/dO  = dL/dY @ W2     (chain rule through W2)
#   dL/dX  = dL/dO @ W1     (chain rule through W1)
#   dL/dW1 = dL/dOᵀ @ X
#   dL/dW2 = dL/dYᵀ @ O
# ═══════════════════════════════════════════════════════════════
loss = Y.mean()
loss.backward()

# ═══════════════════════════════════════════════════════════════
# SHARDED BACKWARD PASS
#
# Forward recap: Y = Σ_g Y_partial_g, loss = mean(Y).
# Since summation distributes the upstream gradient unchanged,
# every GPU sees the SAME dL/dY (broadcast, not reduce):
#   dL/dY_partial_g = dL/dY = ones_like(Y) / Y.numel()
# ═══════════════════════════════════════════════════════════════
dL_by_dY_all_reduced = torch.ones_like(Y) / Y.numel()
dL_by_dY_partial0 = dL_by_dY_all_reduced  # (3, 3)
dL_by_dY_partial1 = dL_by_dY_all_reduced  # (3, 3)
dL_by_dO_shard0 = dL_by_dY_partial0 @ W2_shard0  # (3, 1) = (3, 3) @ (3, 1)
dL_by_dO_shard1 = dL_by_dY_partial1 @ W2_shard1  # (3, 1) = (3, 3) @ (3, 1)

# ─────────────────────────────────────────────────────────────
# Gradient w.r.t. X (needs All-Reduce!)
#   dL/dX_partial_g = dL/dO_shard_g @ W1_shard_g → (3×3) per GPU
# WHY sum? X contributed to O_shard0 AND O_shard1 via different
# weight shards, so its total gradient is the sum of the
# contributions from all GPUs.
# ─────────────────────────────────────────────────────────────
dL_by_dX_partial0 = dL_by_dO_shard0 @ W1_shard0  # (3, 1) @ (1, 3)
dL_by_dX_partial1 = dL_by_dO_shard1 @ W1_shard1  # (3, 1) @ (1, 3)
dL_by_dX = dL_by_dX_partial0 + dL_by_dX_partial1  # All-Reduce (sum)
torch.testing.assert_close(dL_by_dX, X.grad)

# ─────────────────────────────────────────────────────────────
# Gradient w.r.t. weight shards (local, no communication!)
# Each GPU only needs its local activations to update its weight
# shard — the beauty of Megatron sharding: weight grads are local.
#   dL/dW2_shard_g = dL/dY_partial_gᵀ @ O_shard_g → (3×1)
#   dL/dW1_shard_g = dL/dO_shard_gᵀ  @ X          → (1×3)
# ─────────────────────────────────────────────────────────────
dL_by_dW2_shard0 = dL_by_dY_partial0.t() @ O_shard0  # (3, 1) = (3, 3) @ (3, 1)
dL_by_dW2_shard1 = dL_by_dY_partial1.t() @ O_shard1  # (3, 1) = (3, 3) @ (3, 1)
dL_by_dW1_shard0 = dL_by_dO_shard0.t() @ X  # (1, 3) = (1, 3) @ (3, 3)
dL_by_dW1_shard1 = dL_by_dO_shard1.t() @ X  # (1, 3) = (1, 3) @ (3, 3)
torch.testing.assert_close(W1.grad, torch.vstack([dL_by_dW1_shard0, dL_by_dW1_shard1]))  # row-sharded W1
torch.testing.assert_close(W2.grad, torch.hstack([dL_by_dW2_shard0, dL_by_dW2_shard1]))  # column-sharded W2
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment