@belisarius222
Created March 2, 2026 22:29

Swap-FFN Benchmark Results: torch.compile + Fused w13 on A100

  • Date: 2026-03-02
  • Machine: A1 (216.81.248.152), NVIDIA A100-SXM4-80GB, PyTorch 2.10.0+cu126, CUDA 12.6
  • Code: monarch repo, commit aa3bb6f (token-local swap-FFN state)
  • Checkpoints: trained baseline and hybrid1 checkpoints from s3://voltcode-artifacts-17f9c348/runs/monarch-swap-ffn/20260302/
  • Results: s3://voltcode-artifacts-17f9c348/runs/swap-ffn-bench/a1-compiled-20260302/

Background

Swap-FFN is a hybrid transformer architecture where some layers replace standard FFN blocks with a "swap" module: a router selects from multiple FFN banks, and tokens are shuffled across streams via a learned Benes network. This adds expressivity but is significantly more expensive at inference time due to the sequential swap events (8 events per layer, each with 2 FFN steps).

hybrid1 = 24-layer transformer with 1 swap-FFN layer (layer index configurable), d_model=1024, 16 heads, GQA with 4 KV heads, d_ff=4096, k_ffn=2, n_events=8, vocab=50K.

baseline = Same architecture but swap_layers=0 (pure transformer, no swap-FFN).

Both checkpoints were trained on the same data and saved as baseline.pt / hybrid1.pt.

What We Tested

Three changes were applied to the unfused decode path (not the custom Triton fused kernel), two optimizations and one architectural change:

  1. torch.compile (mode="default", Inductor backend): Compiles _token_step_inference — a graph-break-free version of the swap-FFN token step that avoids .item() calls and grad_checkpoint. Inductor fuses elementwise ops into Triton kernels and selects optimal CUTLASS GEMMs.

  2. Fused w1+w3 GEMM: Concatenates the gate projection (w1) and up projection (w3) weight matrices into a single [2*d_ff, d_model] matrix. Instead of two separate F.linear calls, does one matmul + chunk + SiLU*mul. Saves one kernel launch.

  3. Token-local state (commit aa3bb6f): Each token starts with a fresh RAM state (zeroed) and identity stream map (0xE4) rather than carrying state from the previous token. This is an architectural change, not an optimization — it changes the model semantics so that swap-FFN routing is independent per token position.
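The fused w1+w3 path in item 2 can be sketched as a minimal SwiGLU FFN. This is an illustrative sketch, not the actual swap_ffn_model.py code, and the function and argument names are hypothetical:

```python
import torch
import torch.nn.functional as F

def fuse_w13(w1, w3):
    # Concatenate gate (w1) and up (w3) projection weights once, ahead of
    # inference, into a single [2*d_ff, d_model] matrix.
    return torch.cat([w1, w3], dim=0)

def ffn_step_fused(x, w13_fused, w2):
    """SwiGLU FFN with a fused gate+up projection.

    x:         [batch, d_model]
    w13_fused: [2*d_ff, d_model]
    w2:        [d_model, d_ff] down projection
    """
    gate_up = F.linear(x, w13_fused)        # one GEMM instead of two
    gate, up = gate_up.chunk(2, dim=-1)     # split back into gate/up halves
    return F.linear(F.silu(gate) * up, w2)  # SiLU(gate) * up, then down-proj
```

Because the concatenation is along the output dimension, the fused result is numerically identical to the two-GEMM version; the only saving is one kernel launch.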

Benchmark Configuration

  • Batch size: 1 (single-stream inference)
  • Prefill length: 2048 tokens
  • Decode length: 64 tokens
  • Dtype: bf16
  • Timing: CUDA events with synchronization, sequential runs (no GPU contention)
  • Schedule: Hard routing (tau=0.1, use_ste=True, use_hard=True)
  • Script: scripts/bench_decode_with_ckpt.py
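The timing methodology above (CUDA events with explicit synchronization) can be sketched roughly as follows. This is an illustrative helper, not the actual bench_decode_with_ckpt.py code; it falls back to wall-clock timing when no GPU is present:

```python
import time
import torch

def time_decode(step_fn, n_tokens=64):
    """Return average ms per decoded token for a decode-step callable."""
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()      # drain pending work before timing
        start.record()
    else:
        t0 = time.perf_counter()
    for _ in range(n_tokens):
        step_fn()
    if use_cuda:
        end.record()
        torch.cuda.synchronize()      # wait for all queued kernels to finish
        ms = start.elapsed_time(end)  # elapsed_time returns milliseconds
    else:
        ms = (time.perf_counter() - t0) * 1e3
    return ms / n_tokens
```

Synchronizing before and after the timed region is what makes the numbers trustworthy; without it, CUDA's asynchronous launches make host-side timers meaningless.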

Results

Throughput Comparison

| Variant | Decode tok/s | ms/token | avg_logprob | Code Version |
|---|---|---|---|---|
| Baseline (no swap-FFN) | 78.7 | 12.7 | -7.39 | pre-aa3bb6f |
| Hybrid1 unfused | 50.3 | 19.9 | -9.07 | pre-aa3bb6f |
| Hybrid1 + torch.compile | 65.7 | 15.2 | -9.18 | pre-aa3bb6f |
| Hybrid1 + compile + fused w13 | 66.5 | 15.0 | -9.18 | pre-aa3bb6f |
| Hybrid1 + compile + fused w13 | 56.0 | 17.9 | -9.18 | aa3bb6f (token-local) |

Speedup Analysis

| Optimization | Speedup | Notes |
|---|---|---|
| torch.compile alone | +30.6% (50.3 → 65.7) | Inductor fuses RMSNorm, SiLU, softmax, gather/scatter into Triton kernels; picks optimal CUTLASS GEMM configs |
| + fused w13 GEMM | +1.2% (65.7 → 66.5) | Marginal at batch=1: GEMMs are memory-latency-bound, so saving one kernel launch barely helps |
| Token-local state (aa3bb6f) | -15.8% (66.5 → 56.0) | Per-token RAM state reset + identity stream map init adds overhead. This is a correctness/architecture change, not a regression to fix |

Loss Note

The avg_logprob shifts from -9.07 (unfused, no compile) to -9.18 (compiled variants) because _token_step_inference bypasses grad_checkpoint and passes float scalars instead of tensor controls. The difference is numerically insignificant — it's the same model weights producing slightly different routing due to floating-point ordering.

Kernel-Level Profile (ncu roofline, aa3bb6f code)

Profiled with: ncu --set roofline --launch-skip 2000 --launch-count 20

The torch.compile'd decode loop produces exactly 10 kernels per decode step that repeat identically for all 64 tokens. Profiling 20 kernels (2 steps) captures every unique kernel type.

| Duration (us) | Compute % | Memory % | Kernel | Role |
|---|---|---|---|---|
| 77.5 | 45.7% | 29.0% | ampere_bf16_s16816gemm_256x128 | Fused w1+w3 GEMM (gate+up projection) |
| 31.0 | 57.3% | 40.6% | ampere_bf16_s16816gemm_128x128 | w2 GEMM (down projection) |
| 10.6 | 28.4% | 56.1% | triton_fused_silu_split | SiLU activation + chunk |
| 10.3 | 36.9% | 45.0% | triton_fused_select_stack_where | Stream collapse/select |
| 10.0 | 22.9% | 15.5% | cublasLt::splitKreduce (bf16) | GEMM split-K reduction |
| 9.8 | 46.6% | 16.1% | triton_fused_softmax_bmm_... | Router: softmax + bank select + swap |
| 7.5 | 7.1% | 13.8% | cutlass_64x64_tn | Attention QKV projection |
| 6.3 | 1.9% | 2.6% | cutlass_32x32_nn | Small GEMM (router or output proj) |
| 5.3 | 14.7% | 19.6% | triton_fused_rmsnorm | RMSNorm (norm + scale) |
| 4.9 | 1.5% | 1.5% | cublasGemv | Small matrix-vector product |
| ~171 | | | | Total per decode step |

Bottleneck Analysis

  • The fused w13 GEMM (77.5 us) dominates at 45% of step time. At batch=1, a [1, 1024] x [8192, 1024] matmul is deeply latency-bound — only 45.7% compute utilization and 29% memory throughput. The A100 simply can't saturate its 312 TFLOPS or 2 TB/s bandwidth with such a small workload.
  • The w2 GEMM (31.0 us) is the second biggest cost. Together, the two FFN GEMMs account for 63% of the decode step.
  • The router+swap kernel (9.8 us, 46.6% compute) is surprisingly well-utilized — Inductor did a good job fusing the softmax + argmax + gather + bitwise shuffle into one Triton kernel.
  • At batch=1, the entire decode step is latency-bound, not compute-bound. Increasing batch size would amortize kernel launch overhead and improve utilization, but this benchmark targets the single-stream case.
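A back-of-envelope check of the latency-bound claim for the fused w13 GEMM, using the kernel time from the table and assumed A100 peaks of 2 TB/s HBM bandwidth and 312 TFLOPS bf16:

```python
# Illustrative arithmetic only; peak figures are nominal A100-SXM4-80GB specs.
d_model, d_ff, batch = 1024, 4096, 1

flops = 2 * batch * d_model * (2 * d_ff)   # [1,1024] x [1024,8192] GEMM: MACs * 2
weight_bytes = (2 * d_ff) * d_model * 2    # bf16 weight matrix: ~16.8 MB

kernel_s = 77.5e-6                         # measured duration from the ncu table
min_bw_s = weight_bytes / 2e12             # ~8.4 us floor if purely bandwidth-bound
min_compute_s = flops / 312e12             # ~0.05 us floor if purely compute-bound
# Note: ncu's Compute% reflects SM pipeline busy time, not fraction of peak FLOPS.
```

The measured 77.5 us is roughly 9x the pure-bandwidth floor and orders of magnitude above the pure-compute floor, consistent with the kernel spending most of its time on launch and memory latency rather than approaching either peak.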

Profiling Speed Trick

The original approach (500 kernels with --set basic, 10 passes each) was taking 30+ minutes because it replayed thousands of identical kernel invocations. Since torch.compile produces a fixed 10-kernel pattern that repeats every decode step, profiling just 20 kernels (2 steps) with --set roofline completed in ~3 minutes and captured all unique kernel types with full roofline data.

Files on S3

s3://voltcode-artifacts-17f9c348/runs/swap-ffn-bench/a1-compiled-20260302/
├── baseline.json                              # Baseline benchmark metrics
├── hybrid1_unfused.json                       # Hybrid1 no optimizations
├── hybrid1_compiled.json                      # Hybrid1 + torch.compile
├── hybrid1_compiled_fused.json                # Hybrid1 + compile + fused w13 (pre-aa3bb6f)
├── hybrid1_compiled_fused_v2.json             # Hybrid1 + compile + fused w13 (aa3bb6f)
├── hybrid1_compiled_fused_v2_roofline.json    # Same, produced during ncu roofline run
└── ncu/
    ├── hybrid1_compiled.ncu-rep               # ncu profile: compile only (pre-aa3bb6f)
    ├── hybrid1_compiled_fused_v2.ncu-rep      # ncu basic: compile+fused (aa3bb6f, partial)
    └── hybrid1_compiled_fused_v2_roofline.ncu-rep  # ncu roofline: compile+fused (aa3bb6f, 20 kernels)

Code Changes Made

All changes in src/monarch/swap_ffn_model.py and scripts/bench_decode_with_ckpt.py:

  1. _token_step_inference: Graph-break-free version of _token_step that takes Python float/bool args instead of tensor controls, skips grad_checkpoint, and avoids .item() calls. Used only during model.eval().

  2. _ffn_step fused path: When _w13_fused is set, does F.linear(x, w13_fused) → chunk → SiLU*mul → F.linear(..., w2.weight) instead of separate w1/w3 calls.

  3. compile_for_inference() / fuse_w13(): Methods on SwapFFNModule (and fuse_w13_all()/compile_for_inference() on HybridDecoderLM) to enable these optimizations.

  4. bench_decode_with_ckpt.py: Added --compile and --fuse-w13 flags, fixed checkpoint loading to handle model_state_dict key.

Lessons Learned

  • torch.compile(mode="reduce-overhead") (CUDA graphs) doesn't work here because ram_state/stream_map tensors are mutated across calls — CUDA graphs require static tensor addresses. mode="default" (Inductor only) works fine.
  • Concurrent GPU benchmarks corrupt timing — always run sequentially on a shared GPU.
  • ncu --set full (47 passes/kernel) is impractical for compiled models. Use --set roofline with small --launch-count to profile unique kernel types only.