Detailed execution trace comparing the fused Triton kernel path vs the unfused PyTorch/cuBLAS path for Swap-FFN decode inference.
Config: d_model=1024, d_ff=4096, core_count=512, n_events=8, k_ffn=2, d_router=64, collapse_k=8, collapse_r=224
Hardware: A100-SXM4-80GB, 108 SMs, 40 MB L2 cache
Entry: SwapFFNModule.forward → _run_training → _token_step
x_in: [B, D=1024] → reshaped to [B, 1, D]
_prepare_state: returns ram_state [B, 512, 4, 1024] and stream_map [B, 512] (uint8)
_build_core_film: core_embed.weight [512, 32] → core_scale [512, 1024] via Linear(32→1024), plus 1.0
_build_core_shift: core_embed.weight [512, 32] → core_shift [512, 1024] via Linear(32→1024)
_build_core_alpha: core_embed.weight [512, 32] → core_alpha [512, 8] via Linear(32→8), then softmax
These are 3 small matmuls: [512, 32] × [1024, 32]^T, etc. Negligible cost.
core_x = x_tok[:, None, :] * core_scale[None, :, :] + core_shift[None, :, :]
# [B, 1, 1024] * [1, 512, 1024] + [1, 512, 1024] → [B, 512, 1024]
One elementwise kernel. Broadcasts the single token across all 512 cores with a per-core affine transform.
GPU ops: 1 kernel launch (elementwise multiply+add)
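This broadcast can be sketched and shape-checked in a few lines of PyTorch (B=2 is an arbitrary illustrative batch size; the other shapes come from the config):

```python
import torch

B, C, D = 2, 512, 1024
x_tok = torch.randn(B, D)          # one token per batch item
core_scale = torch.randn(C, D)     # per-core FiLM scale
core_shift = torch.randn(C, D)     # per-core FiLM shift

# [B, 1, D] * [1, C, D] + [1, C, D] -> [B, C, D], one elementwise kernel
core_x = x_tok[:, None, :] * core_scale[None, :, :] + core_shift[None, :, :]
assert core_x.shape == (B, C, D)
```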
Each iteration of the event loop:
# Router normalization
x_norm = _norm_core(core_x) * router_norm # [B, 512, 1024]
GPU ops: 2 kernels (reduce for norm, elementwise for scale+multiply)
# Router projection
z = einsum('bcd,md->bcm', x_norm, router_w) # [B, 512, 64]
GPU ops: 1 cuBLAS GEMM — reshapes to [B*512, 1024] × [64, 1024]^T → [B*512, 64]. Reads router_w (64×1024×2 = 128 KB) once for all 512 cores.
# Swap decision
swap_logit = einsum('bcm,m->bc', z, router_swap_w) + router_swap_b # [B, 512]
gate = sigmoid(swap_logit)
gate_hard = gate > 0.5
GPU ops: ~2 kernels (reduce + elementwise sigmoid)
# Bank selection
bank_logits = einsum('bcm,mk->bck', z, bank_w) + bank_b # [B, 512, 4]
selected_logical = argmax(bank_logits, dim=-1) # [B, 512] (softmax preserves argmax, so logits suffice)
GPU ops: 1 GEMM ([B*512, 64] × [4, 64]^T) + 1-2 elementwise/reduce kernels
# Stream lane decode + physical bank mapping
lane0, lane1, lane2, lane3 = decode_stream_lanes(stream_map)
selected_physical = where(selected_logical == 0, lane0, where(...))
GPU ops: ~4 elementwise kernels (bitwise ops + nested where)
# RAM read: gather selected bank
selected_bank = gather(ram_state, dim=2, index=selected_physical[..., None, None].expand(...))
# [B, 512, 1, 1024] → squeeze → [B, 512, 1024]
GPU ops: 1 gather kernel
# Swap x with selected bank
x = where(gate_hard[:,:,None], selected_bank, x) # [B, 512, 1024]
# Write x_before (the pre-swap value of x) back to RAM
ram_new0 = where(mask0[:,:,None], x_before, ram_state[:,:,0,:]) # ×4 banks
ram_state_next = stack([ram_new0, ram_new1, ram_new2, ram_new3], dim=2)
GPU ops: ~5 kernels (4 conditional writes + 1 stack)
# Update stream map lanes
lane0_next = where(gate_hard, selected_physical, lane0)
# ... pack back to uint8
stream_map_next = pack_stream_lanes(lane0_next, lane1_next, lane2_next, lane3_next)
GPU ops: ~4 kernels (3 where + 1 pack)
Subtotal for _swap_once: ~20 kernel launches
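The 2-bit lane packing used throughout the swap step round-trips cleanly; a sketch (the function names come from the trace, the bodies are an assumed implementation matching the bit layout the fused kernel uses):

```python
import torch

def decode_stream_lanes(stream_map):
    """Unpack four 2-bit physical-bank indices from a uint8 [B, C] map."""
    lane0 = stream_map & 0x03
    lane1 = (stream_map >> 2) & 0x03
    lane2 = (stream_map >> 4) & 0x03
    lane3 = (stream_map >> 6) & 0x03
    return lane0, lane1, lane2, lane3

def pack_stream_lanes(lane0, lane1, lane2, lane3):
    """Repack four 2-bit lanes into one uint8 per (batch, core)."""
    return lane0 | (lane1 << 2) | (lane2 << 4) | (lane3 << 6)

sm = torch.randint(0, 256, (1, 512), dtype=torch.uint8)
assert torch.equal(pack_stream_lanes(*decode_stream_lanes(sm)), sm)
```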
x_norm = _norm_core(core_x) * self.norm.weight # [B, 512, 1024]
gate = F.silu(F.linear(x_norm, w1.weight)) # [B, 512, 4096]
up = F.linear(x_norm, w3.weight) # [B, 512, 4096]
out = F.linear(gate * up, w2.weight) # [B, 512, 1024]
The critical part: each F.linear call on [B, 512, 1024] with weight [4096, 1024]:
- PyTorch reshapes the input to [B*512, 1024]
- Calls a cuBLAS GEMM: [B*512, 1024] × [4096, 1024]^T → [B*512, 4096]
- Reads the w1 weights (4096×1024×2 = 8 MB) ONCE for all 512 cores
- cuBLAS tiles the computation across SMs: each SM processes a block of rows (cores) × a block of output columns, loading weight tiles from L2/DRAM and reusing them across rows
GPU ops per FFN step: 3 GEMMs + 2 elementwise (norm, silu*up) = 5 kernels
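The reshape-and-GEMM behavior of F.linear described above can be checked directly (a sketch; `w1` here stands in for `w1.weight`, and float64 is used so the comparison is exact):

```python
import torch
import torch.nn.functional as F

B, C, D, FF = 1, 512, 1024, 4096
x_norm = torch.randn(B, C, D, dtype=torch.double)
w1 = torch.randn(FF, D, dtype=torch.double)   # stands in for w1.weight

# F.linear flattens the leading dims and issues a single GEMM:
# [B*C, D] x [FF, D]^T -> [B*C, FF], reading w1 once for all 512 cores
out = F.linear(x_norm, w1)
ref = (x_norm.reshape(B * C, D) @ w1.t()).reshape(B, C, FF)
assert torch.allclose(out, ref)
```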
Subtotal per event: 20 (swap) + 2 × 5 (FFN) = 30 kernel launches
Subtotal for all 8 events: 8 × 30 = 240 kernel launches
pooled = einsum('bcd,ck->bkd', core_x, core_alpha) # [B, 8, 1024]
basis_prod = einsum('bkd,krd->bkr', pooled, collapse_basis) # [B, 8, 224]
z = basis_prod.sum(dim=1) # [B, 224]
z = F.silu(z) # [B, 224]
delta = F.linear(z, collapse_decoder) # [B, 1024]
GPU ops: ~5 kernels (2 GEMMs, 1 reduce, 1 silu, 1 linear)
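The collapse pipeline is two einsum contractions, a sum over k, and a small decoder; a shape-level sketch with the config's dimensions (collapse_k=8, collapse_r=224):

```python
import torch
import torch.nn.functional as F

B, C, D, K, R = 1, 512, 1024, 8, 224
core_x = torch.randn(B, C, D)
core_alpha = torch.randn(C, K).softmax(dim=-1)   # per-core event weights
collapse_basis = torch.randn(K, R, D)
collapse_decoder = torch.randn(D, R)             # F.linear weight: [out, in]

pooled = torch.einsum('bcd,ck->bkd', core_x, core_alpha)           # [B, K, D]
basis_prod = torch.einsum('bkd,krd->bkr', pooled, collapse_basis)  # [B, K, R]
z = F.silu(basis_prod.sum(dim=1))                                  # [B, R]
delta = F.linear(z, collapse_decoder)                              # [B, D]
assert delta.shape == (B, D)
```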
- ~3 setup kernels
- ~240 kernels (event loop)
- ~5 collapse kernels
- ~248 kernel launches total
Each FFN step reads w1, w3, w2 once:
- 3 × 8 MB = 24 MB per FFN step
- 16 FFN steps (8 events × 2 k_ffn) × 24 MB = 384 MB total weight reads
The weights likely stay hot in L2 (40 MB on A100) across the 512 "rows" within each GEMM call, so actual DRAM traffic per GEMM ≈ 8 MB (weight size). Between successive GEMM calls (w1→w3→w2→w1→...), the previous weight gets evicted, so each weight is re-fetched from DRAM each time.
Entry: SwapFFNModule.forward → _run_fused_decode → swap_ffn_fused_decode → swap_ffn_fused_decode_triton → _launch_swap_ffn_decode_kernel
Same core_scale/core_shift/core_alpha computation as unfused. Then:
# Allocate workspaces
ffn_workspace: [B * 512, 4096 + 1024] float32 # ~10 MB for B=1
collapse_workspace: [B, 224] float32 # tiny
collapse_workspace.zero_() # must be zeroed for atomic_add
grid = (B * 512,) # 512 blocks for B=1
_swap_ffn_decode_kernel[(512,)](...)
# num_warps=4, num_stages=2
# Each block: 4 warps × 32 threads = 128 threads
One kernel launch. 512 blocks are scheduled across the A100's 108 SMs. Each block processes one (batch_item, core) pair entirely independently.
x = load x_ptr[b, 0:1024] # 2 KB (bf16)
core_scale = load core_scale_ptr[c, 0:1024] # 2 KB
core_shift = load core_shift_ptr[c, 0:1024] # 2 KB
x = x * core_scale + core_shift # in registers
stream_u8 = load stream_in[b, c] # 1 byte
lane0 = stream_u8 & 0x03
lane1 = (stream_u8 >> 2) & 0x03
lane2 = (stream_u8 >> 4) & 0x03
lane3 = (stream_u8 >> 6) & 0x03
ram0 = load ram[b, c, 0, 0:1024] # 2 KB
ram1 = load ram[b, c, 1, 0:1024] # 2 KB
ram2 = load ram[b, c, 2, 0:1024] # 2 KB
ram3 = load ram[b, c, 3, 0:1024] # 2 KB
All of x, ram0-3, core_scale, core_shift live in registers (float32). That's 7 × 1024 floats = 7168 floats = 28 KB of register state per block. Spread over 128 threads, that's ~224 bytes/thread, i.e. 56 32-bit registers just for this live state; loop variables, addresses, and intermediates push the total to 255 registers/thread (the ncu-measured value). This is at the hardware limit.
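The register-budget arithmetic can be spelled out (a back-of-envelope sketch; the 255 regs/thread figure is the measured ncu value, the rest is derived from the shapes above):

```python
# 7 float32 vectors of length 1024 held live in registers per block
vectors, d, bytes_per_f32 = 7, 1024, 4
state_bytes = vectors * d * bytes_per_f32      # 28672 B = 28 KB per block

threads = 128                                  # 4 warps x 32 threads
per_thread = state_bytes // threads            # 224 B per thread
regs_for_state = per_thread // bytes_per_f32   # 56 32-bit registers
assert (state_bytes, per_thread, regs_for_state) == (28672, 224, 56)
# Loop variables, addresses, and intermediates push the total to the
# measured 255 regs/thread -- the per-thread hardware cap on A100.
```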
Each iteration:
Router:
# RMS norm
x_norm = x * rsqrt(sum(x*x) / 1024 + eps)
x_norm *= load(router_norm[0:1024]) # 2 KB, shared across all blocks
# Router matmul — SEQUENTIAL dot products
z[0:64] = zeros
for q in range(64): # SERIAL loop!
w = load router_w[q, 0:1024] # 2 KB per iteration
z[q] = sum(w * x_norm) # reduce across D=1024
This is not a GEMM. It's 64 sequential vector dot products. Each loads a full row of router_w (2 KB). Total: 64 × 2 KB = 128 KB of weight reads per block per event.
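The 64 serial dot products compute exactly what a single [64, 1024] matrix-vector product would; a sketch contrasting the two formulations (float64 so the comparison is exact):

```python
import torch

D, M = 1024, 64
x_norm = torch.randn(D, dtype=torch.double)
router_w = torch.randn(M, D, dtype=torch.double)

# Fused-kernel formulation: 64 sequential reductions, one 2 KB row each
z_serial = torch.stack([(router_w[q] * x_norm).sum() for q in range(M)])

# cuBLAS formulation: one matvec over the whole weight matrix
z_gemv = router_w @ x_norm
assert torch.allclose(z_serial, z_gemv)
```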
Swap decision:
w_swap = load router_swap_w[0:64] # 128 bytes
swap_logit = sum(z * w_swap) + load(router_swap_b)
gate = 1 / (1 + exp(-swap_logit))
gate_hard = gate > 0.5
Bank selection:
# 4 dot products of z[64] with bank_w columns
for bank in 0..3:
bank_score[bank] = sum(z * load(bank_w[:, bank])) + load(bank_b[bank])
selected_logical = argmax(bank_score)
selected_physical = lane_map[selected_logical]
RAM swap:
selected_bank = load ram[b, c, selected_physical, :] # already in registers!
x_before = x
x = where(gate_hard, selected_bank, x)
ram[selected_physical] = where(gate_hard, x_before, ram[selected_physical])
# Update lane mappings
All in registers, no DRAM access.
FFN — THE BOTTLENECK (K_FFN=2 times):
# RMS norm
x_norm = x * rsqrt(sum(x*x)/1024 + eps) * load(ffn_norm[0:1024])
store x_norm to ffn_workspace[pid, 4096:5120] # 4 KB write
# Gate (w1) + Up (w3) matmul — TILED
for ff_base in range(0, 4096, 128): # 32 tiles
gate_acc[0:128] = zeros
up_acc[0:128] = zeros
for d_base in range(0, 1024, 128): # 8 tiles
x_seg = load workspace[pid, 4096+d_base : 4096+d_base+128]
w1_tile = load w1[ff_base:ff_base+128, d_base:d_base+128] # 32 KB (bf16)
w3_tile = load w3[ff_base:ff_base+128, d_base:d_base+128] # 32 KB
gate_acc += w1_tile @ x_seg # [128] += [128,128] × [128]
up_acc += w3_tile @ x_seg
silu_gate = gate_acc / (1 + exp(-gate_acc))
inter = silu_gate * up_acc
store inter to workspace[pid, ff_base:ff_base+128]
# Down (w2) matmul — TILED
for d_out_base in range(0, 1024, 128): # 8 tiles
out[0:128] = zeros
for ff_base in range(0, 4096, 128): # 32 tiles
inter = load workspace[pid, ff_base:ff_base+128]
w2_tile = load w2[d_out_base:d_out_base+128, ff_base:ff_base+128] # 32 KB
out += w2_tile @ inter
store out to workspace[pid, 4096+d_out_base : ...]
x = load workspace[pid, 4096:5120] # reload result
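The tiled loops above are a plain matrix-vector product done in 128-wide blocks; a host-side sketch checking the w1 tiling pattern against the untiled result (float64 so the comparison is exact):

```python
import torch

D, FF, T = 1024, 4096, 128
x = torch.randn(D, dtype=torch.double)
w1 = torch.randn(FF, D, dtype=torch.double)

gate = torch.zeros(FF, dtype=torch.double)
for ff_base in range(0, FF, T):            # 32 output tiles
    acc = torch.zeros(T, dtype=torch.double)
    for d_base in range(0, D, T):          # 8 reduction tiles
        w1_tile = w1[ff_base:ff_base+T, d_base:d_base+T]  # [128, 128] tile
        acc += w1_tile @ x[d_base:d_base+T]
    gate[ff_base:ff_base+T] = acc

assert torch.allclose(gate, w1 @ x)
```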
Per FFN step, per block, weight data loaded from DRAM/L2:
- w1: 32 tiles × 8 subtiles × 32 KB = 8 MB
- w3: same = 8 MB
- w2: 8 tiles × 32 subtiles × 32 KB = 8 MB
- Total: 24 MB per block per FFN step
Per event: 2 × 24 MB = 48 MB per block
Per kernel (8 events): 8 × 48 MB = 384 MB per block
With 512 blocks all loading the same weights: best case (perfect L2 sharing) would be 384 MB total. Worst case (no sharing) would be 512 × 384 MB = 192 GB.
The ncu measured 6.5 GB DRAM reads, which means L2 helps substantially but is far from perfect. With 512 blocks competing for 40 MB L2, weight tiles get evicted frequently.
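The bounds quoted above follow from simple arithmetic (a sketch; the 6.5 GB figure is the ncu measurement, everything else is derived from the trace):

```python
weights_mb = 24       # w1 + w3 + w2 in bf16, 8 MB each
ffn_steps = 16        # 8 events x k_ffn=2
blocks = 512

best_case_mb = ffn_steps * weights_mb    # perfect L2 sharing: 384 MB
worst_case_mb = blocks * best_case_mb    # no sharing at all: 196608 MB = 192 GB
measured_mb = 6.5 * 1024                 # 6656 MB read from DRAM per ncu

assert best_case_mb == 384 and worst_case_mb == 196608
# Measured traffic sits between the bounds: L2 helps, but far from perfectly.
assert best_case_mb < measured_mb < worst_case_mb
```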
store ram0 → ram_out[b, c, 0, 0:1024]
store ram1 → ram_out[b, c, 1, 0:1024]
store ram2 → ram_out[b, c, 2, 0:1024]
store ram3 → ram_out[b, c, 3, 0:1024]
store new_map → stream_out[b, c]
for r in range(0, 224, 32): # 7 r-blocks
accum[0:32] = zeros
for k in range(8): # collapse_k
alpha = load core_alpha[c, k] # scalar
for d_base in range(0, 1024, 128): # 8 d-blocks
x_seg = load workspace[pid, 4096+d_base:...]
basis_tile = load collapse_basis[k, r:r+32, d_base:d_base+128] # 8 KB
accum += alpha * (basis_tile @ x_seg)
atomic_add(collapse_workspace[b, r:r+32], accum)
Per block: reads collapse_basis = 7 × 8 × 8 × 8 KB = 3.5 MB.
512 blocks × 3.5 MB = 1.8 GB (unique but shared via L2).
# Back in Python after kernel completes
if collapse_silu:
collapse_workspace = F.silu(collapse_workspace) # [B, 224]
delta = F.linear(collapse_workspace, collapse_decoder) # [B, 224] × [1024, 224]^T → [B, 1024]
GPU ops: 2 more kernels (silu + gemm), trivially small.
- 1 Triton kernel launch (75.1 ms)
- 2 tiny follow-up kernels (collapse_silu + collapse_decoder matmul)
| | Unfused | Fused |
|---|---|---|
| Kernel launches | ~248 | 3 |
| Weight loading strategy | cuBLAS GEMM: load w1 once, apply to all 512 rows | Each of 512 blocks loads w1 independently in tiles |
| Weight DRAM reads | ~384 MB (16 FFN steps × 24 MB) | ~6.5 GB (measured by ncu) |
| FFN compute | cuBLAS (highly optimized tiling, tensor cores) | Triton tiled loops, 128 threads, sequential |
| Register pressure | Low (cuBLAS manages internally) | 255 regs/thread (hardware limit), kills occupancy |
| SM utilization | High (cuBLAS saturates SMs) | 35.6% (low occupancy from register pressure) |
| Router compute | 1 batched GEMM: [B*512, 1024] × [64, 1024]^T | 64 sequential dot products per block (serial q loop) |
| Throughput (B=1) | 49.3 tok/s | 13.7 tok/s |
The fundamental issue: cuBLAS computes [512, 1024] × [4096, 1024]^T as a single operation, loading each weight tile once and broadcasting it across all 512 input rows (cores) using shared memory within each SM. The fused Triton kernel gives each core its own block, and each block redundantly loads the entire weight matrix. The weights are shared (same w1/w3/w2 for all cores), but the kernel structure doesn't exploit that — 512 blocks compete for 40 MB of L2 cache, causing massive DRAM re-reads.
Fused Triton _swap_ffn_decode_kernel:
- Duration: 75.1 ms per launch
- DRAM reads: 6.5 GB
- DRAM writes: 262 MB
- Registers/thread: 255
- SM utilization: 35.6%
- DRAM throughput: 4.4%
- Instructions executed: 12.95 billion
Unfused cuBLAS gemm (decode hot path):
- Duration: 12.3 µs per launch
- DRAM reads: 8.5 MB per launch