Date: 2026-03-02
Machine: A1 (216.81.248.152), NVIDIA A100-SXM4-80GB, PyTorch 2.10.0+cu126, CUDA 12.6
Code: monarch repo, commit aa3bb6f (token-local swap-FFN state)
Checkpoints: Trained baseline and hybrid1 checkpoints from s3://voltcode-artifacts-17f9c348/runs/monarch-swap-ffn/20260302/
Results: s3://voltcode-artifacts-17f9c348/runs/swap-ffn-bench/a1-compiled-20260302/
Swap-FFN is a hybrid transformer architecture where some layers replace standard FFN blocks with a "swap" module: a router selects from multiple FFN banks, and tokens are shuffled across streams via a learned Benes network. This adds expressivity but is significantly more expensive at inference time due to the sequential swap events (8 events per layer, each with 2 FFN steps).
hybrid1 = 24-layer transformer with 1 swap-FFN layer (layer index configurable), d_model=1024, 16 heads, GQA with 4 KV heads, d_ff=4096, k_ffn=2, n_events=8, vocab=50K.
baseline = Same architecture but swap_layers=0 (pure transformer, no swap-FFN).
Both checkpoints were trained on the same data and saved as baseline.pt / hybrid1.pt.
Three optimizations applied to the unfused decode path (not the custom Triton fused kernel):
-
torch.compile (mode="default", Inductor backend): Compiles
_token_step_inference— a graph-break-free version of the swap-FFN token step that avoids.item()calls andgrad_checkpoint. Inductor fuses elementwise ops into Triton kernels and selects optimal CUTLASS GEMMs. -
Fused w1+w3 GEMM: Concatenates the gate projection (w1) and up projection (w3) weight matrices into a single
[2*d_ff, d_model]matrix. Instead of two separateF.linearcalls, does one matmul + chunk + SiLU*mul. Saves one kernel launch. -
Token-local state (commit
aa3bb6f): Each token starts with a fresh RAM state (zeroed) and identity stream map (0xE4) rather than carrying state from the previous token. This is an architectural change, not an optimization — it changes the model semantics so that swap-FFN routing is independent per token position.
- Batch size: 1 (single-stream inference)
- Prefill length: 2048 tokens
- Decode length: 64 tokens
- Dtype: bf16
- Timing: CUDA events with synchronization, sequential runs (no GPU contention)
- Schedule: Hard routing (tau=0.1, use_ste=True, use_hard=True)
- Script:
scripts/bench_decode_with_ckpt.py
| Variant | Decode tok/s | ms/token | avg_logprob | Code Version |
|---|---|---|---|---|
| Baseline (no swap-FFN) | 78.7 | 12.7 | -7.39 | pre-aa3bb6f |
| Hybrid1 unfused | 50.3 | 19.9 | -9.07 | pre-aa3bb6f |
| Hybrid1 + torch.compile | 65.7 | 15.2 | -9.18 | pre-aa3bb6f |
| Hybrid1 + compile + fused w13 | 66.5 | 15.0 | -9.18 | pre-aa3bb6f |
| Hybrid1 + compile + fused w13 | 56.0 | 17.9 | -9.18 | aa3bb6f (token-local) |
| Optimization | Speedup | Notes |
|---|---|---|
| torch.compile alone | +30.6% (50.3 → 65.7) | Inductor fuses RMSNorm, SiLU, softmax, gather/scatter into Triton kernels; picks optimal CUTLASS GEMM configs |
| + fused w13 GEMM | +1.2% (65.7 → 66.5) | Marginal at batch=1 — GEMMs are memory-latency-bound, saving one kernel launch barely helps |
| Token-local state (aa3bb6f) | -15.8% (66.5 → 56.0) | Per-token RAM state reset + identity stream map init adds overhead. This is a correctness/architecture change, not a regression to fix |
The avg_logprob shifts from -9.07 (unfused, no compile) to -9.18 (compiled variants) because _token_step_inference bypasses grad_checkpoint and passes float scalars instead of tensor controls. The difference is numerically insignificant — it's the same model weights producing slightly different routing due to floating-point ordering.
Profiled with: ncu --set roofline --launch-skip 2000 --launch-count 20
The torch.compile'd decode loop produces exactly 10 kernels per decode step that repeat identically for all 64 tokens. Profiling 20 kernels (2 steps) captures every unique kernel type.
| Duration (us) | Compute% | Memory% | Kernel | Role |
|---|---|---|---|---|
| 77.5 | 45.7% | 29.0% | ampere_bf16_s16816gemm_256x128 |
Fused w1+w3 GEMM (gate+up projection) |
| 31.0 | 57.3% | 40.6% | ampere_bf16_s16816gemm_128x128 |
w2 GEMM (down projection) |
| 10.6 | 28.4% | 56.1% | triton_fused_silu_split |
SiLU activation + chunk |
| 10.3 | 36.9% | 45.0% | triton_fused_select_stack_where |
Stream collapse/select |
| 10.0 | 22.9% | 15.5% | cublasLt::splitKreduce (bf16) |
GEMM split-K reduction |
| 9.8 | 46.6% | 16.1% | triton_fused_softmax_bmm_... |
Router: softmax + bank select + swap |
| 7.5 | 7.1% | 13.8% | cutlass_64x64_tn |
Attention QKV projection |
| 6.3 | 1.9% | 2.6% | cutlass_32x32_nn |
Small GEMM (router or output proj) |
| 5.3 | 14.7% | 19.6% | triton_fused_rmsnorm |
RMSNorm (norm + scale) |
| 4.9 | 1.5% | 1.5% | cublasGemv |
Small matrix-vector product |
| ~171 | Total per decode step |
- The fused w13 GEMM (77.5 us) dominates at 45% of step time. At batch=1, a
[1, 1024] x [8192, 1024]matmul is deeply latency-bound — only 45.7% compute utilization and 29% memory throughput. The A100 simply can't saturate its 312 TFLOPS or 2 TB/s bandwidth with such a small workload. - The w2 GEMM (31.0 us) is the second biggest cost. Together, the two FFN GEMMs account for 63% of the decode step.
- The router+swap kernel (9.8 us, 46.6% compute) is surprisingly well-utilized — Inductor did a good job fusing the softmax + argmax + gather + bitwise shuffle into one Triton kernel.
- At batch=1, the entire decode step is latency-bound, not compute-bound. Increasing batch size would amortize kernel launch overhead and improve utilization, but this benchmark targets the single-stream case.
The original approach (500 kernels with --set basic, 10 passes each) was taking 30+ minutes because it replayed thousands of identical kernel invocations. Since torch.compile produces a fixed 10-kernel pattern that repeats every decode step, profiling just 20 kernels (2 steps) with --set roofline completed in ~3 minutes and captured all unique kernel types with full roofline data.
s3://voltcode-artifacts-17f9c348/runs/swap-ffn-bench/a1-compiled-20260302/
├── baseline.json # Baseline benchmark metrics
├── hybrid1_unfused.json # Hybrid1 no optimizations
├── hybrid1_compiled.json # Hybrid1 + torch.compile
├── hybrid1_compiled_fused.json # Hybrid1 + compile + fused w13 (pre-aa3bb6f)
├── hybrid1_compiled_fused_v2.json # Hybrid1 + compile + fused w13 (aa3bb6f)
├── hybrid1_compiled_fused_v2_roofline.json # Same, produced during ncu roofline run
└── ncu/
├── hybrid1_compiled.ncu-rep # ncu profile: compile only (pre-aa3bb6f)
├── hybrid1_compiled_fused_v2.ncu-rep # ncu basic: compile+fused (aa3bb6f, partial)
└── hybrid1_compiled_fused_v2_roofline.ncu-rep # ncu roofline: compile+fused (aa3bb6f, 20 kernels)
All changes in src/monarch/swap_ffn_model.py and scripts/bench_decode_with_ckpt.py:
-
_token_step_inference: Graph-break-free version of_token_stepthat takes Python float/bool args instead of tensor controls, skipsgrad_checkpoint, and avoids.item()calls. Used only duringmodel.eval(). -
_ffn_stepfused path: When_w13_fusedis set, doesF.linear(x, w13_fused)→ chunk → SiLU*mul →F.linear(..., w2.weight)instead of separate w1/w3 calls. -
compile_for_inference()/fuse_w13(): Methods onSwapFFNModule(andfuse_w13_all()/compile_for_inference()onHybridDecoderLM) to enable these optimizations. -
bench_decode_with_ckpt.py: Added--compileand--fuse-w13flags, fixed checkpoint loading to handlemodel_state_dictkey.
torch.compile(mode="reduce-overhead")(CUDA graphs) doesn't work here because ram_state/stream_map tensors are mutated across calls — CUDA graphs require static tensor addresses.mode="default"(Inductor only) works fine.- Concurrent GPU benchmarks corrupt timing — always run sequentially on a shared GPU.
- ncu
--set full(47 passes/kernel) is impractical for compiled models. Use--set rooflinewith small--launch-countto profile unique kernel types only.