Swap-FFN Fused vs Unfused Decode: Detailed Execution Trace (A100)

Swap-FFN Fused vs Unfused Decode: Execution Trace

Detailed execution trace comparing the fused Triton kernel path vs the unfused PyTorch/cuBLAS path for Swap-FFN decode inference.

Config: d_model=1024, d_ff=4096, core_count=512, n_events=8, k_ffn=2, d_router=64, collapse_k=8, collapse_r=224

Hardware: A100-SXM4-80GB, 108 SMs, 40 MB L2 cache


UNFUSED PATH

Entry: SwapFFNModule.forward → _run_training → _token_step

Step 0: Setup (_run_training)

x_in: [B, D=1024]  →  reshaped to [B, 1, D]
  • _prepare_state: returns ram_state [B, 512, 4, 1024], stream_map [B, 512] (uint8)
  • _build_core_film: core_embed.weight [512, 32] → core_scale [512, 1024] via Linear(32→1024), plus 1.0
  • _build_core_shift: core_embed.weight [512, 32] → core_shift [512, 1024] via Linear(32→1024)
  • _build_core_alpha: core_embed.weight [512, 32] → core_alpha [512, 8] via Linear(32→8), then softmax

These are 3 small matmuls: [512, 32] × [1024, 32]^T, etc. Negligible cost.

Step 1: Core broadcast (_token_step)

core_x = x_tok[:, None, :] * core_scale[None, :, :] + core_shift[None, :, :]
# [B, 1, 1024] * [1, 512, 1024] + [1, 512, 1024] → [B, 512, 1024]

One elementwise kernel. Broadcasts the single token across all 512 cores with per-core affine transform.

GPU ops: 1 kernel launch (elementwise multiply+add)
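The broadcast can be sketched in plain Python for a single batch item, with toy sizes standing in for 512 cores and D=1024. `core_broadcast` is a hypothetical name used only for illustration; it shows that every core receives the same token, differing only by its own scale and shift rows.

```python
def core_broadcast(x_tok, core_scale, core_shift):
    """Per-core affine transform of one token.

    x_tok:      [D] token vector for one batch item
    core_scale: [C][D] per-core scale rows
    core_shift: [C][D] per-core shift rows
    returns:    [C][D] with out[c][d] = x_tok[d] * core_scale[c][d] + core_shift[c][d]
    """
    return [
        [x * s + t for x, s, t in zip(x_tok, scale_row, shift_row)]
        for scale_row, shift_row in zip(core_scale, core_shift)
    ]


# Toy example: C=2 cores, D=3 features
x_tok = [1.0, 2.0, 3.0]
core_scale = [[1.0, 1.0, 1.0], [2.0, 0.5, 0.0]]
core_shift = [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]
core_x = core_broadcast(x_tok, core_scale, core_shift)
print(core_x)  # [[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]]
```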

Step 2: Event loop (8 iterations of _swap_once + _ffn_step)

Each iteration of the event loop:

2a. _swap_once

# Router normalization
x_norm = _norm_core(core_x) * router_norm     # [B, 512, 1024]

GPU ops: 2 kernels (reduce for norm, elementwise for scale+multiply)

# Router projection
z = einsum('bcd,md->bcm', x_norm, router_w)   # [B, 512, 64]

GPU ops: 1 cuBLAS GEMM — reshapes to [B*512, 1024] × [64, 1024]^T → [B*512, 64]. Reads router_w (64×1024×2 = 128 KB) once for all 512 cores.

# Swap decision
swap_logit = einsum('bcm,m->bc', z, router_swap_w) + router_swap_b  # [B, 512]
gate = sigmoid(swap_logit)
gate_hard = gate > 0.5

GPU ops: ~2 kernels (reduce + elementwise sigmoid)

# Bank selection
bank_logits = einsum('bcm,mk->bck', z, bank_w) + bank_b  # [B, 512, 4]
bank_probs = softmax(bank_logits, dim=-1)                 # [B, 512, 4]
selected_logical = argmax(bank_probs, dim=-1)              # [B, 512]

GPU ops: 1 GEMM ([B*512, 64] × [4, 64]^T) + 1-2 elementwise/reduce kernels

# Stream lane decode + physical bank mapping
lane0, lane1, lane2, lane3 = decode_stream_lanes(stream_map)
selected_physical = where(selected_logical == 0, lane0, where(...))

GPU ops: ~4 elementwise kernels (bitwise ops + nested where)

# RAM read: gather selected bank
selected_bank = gather(ram_state, dim=2, index=selected_physical[..., None, None].expand(...))
# [B, 512, 1, 1024] → squeeze → [B, 512, 1024]

GPU ops: 1 gather kernel

# Swap x with selected bank (keep the pre-swap value for the write-back)
x_prev = x
x = where(gate_hard[:,:,None], selected_bank, x)           # [B, 512, 1024]
# Write x_prev back to RAM
ram_new0 = where(mask0[:,:,None], x_prev, ram_state[:,:,0,:])  # ×4 banks
ram_state_next = stack([ram_new0, ram_new1, ram_new2, ram_new3], dim=2)

GPU ops: ~5 kernels (4 conditional writes + 1 stack)

# Update stream map lanes
lane0_next = where(gate_hard, selected_physical, lane0)
# ... pack back to uint8
stream_map_next = pack_stream_lanes(lane0_next, lane1_next, lane2_next, lane3_next)

GPU ops: ~4 kernels (3 where + 1 pack)

Subtotal for _swap_once: ~20 kernel launches

2b. _ffn_step — called K_FFN=2 times per event

x_norm = _norm_core(core_x) * self.norm.weight    # [B, 512, 1024]
gate = F.silu(F.linear(x_norm, w1.weight))         # [B, 512, 4096]
up = F.linear(x_norm, w3.weight)                   # [B, 512, 4096]
out = F.linear(gate * up, w2.weight)               # [B, 512, 1024]

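The critical part: each F.linear call, e.g. w1 on input [B, 512, 1024] with weight [4096, 1024] (w3 has the same shapes; w2 maps [B, 512, 4096] back to [B, 512, 1024] with weight [1024, 4096]):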

  • PyTorch reshapes input to [B*512, 1024]
  • Calls cuBLAS GEMM: [B*512, 1024] × [4096, 1024]^T → [B*512, 4096]
  • Reads w1 weights (4096×1024×2 = 8 MB) ONCE for all 512 cores
  • cuBLAS tiles the computation across SMs, with each SM processing a block of rows (cores) × a block of output columns, loading weight tiles from L2/DRAM and reusing them across rows

GPU ops per FFN step: 3 GEMMs + 2 elementwise (norm, silu*up) = 5 kernels

Subtotal per event: 20 (swap) + 2 × 5 (FFN) = 30 kernel launches

Subtotal for all 8 events: 8 × 30 = 240 kernel launches

Step 3: Collapse (_collapse)

pooled = einsum('bcd,ck->bkd', core_x, core_alpha)        # [B, 8, 1024]
basis_prod = einsum('bkd,krd->bkr', pooled, collapse_basis) # [B, 8, 224]
z = basis_prod.sum(dim=1)                                    # [B, 224]
z = F.silu(z)                                                # [B, 224]
delta = F.linear(z, collapse_decoder)                        # [B, 1024]

GPU ops: ~5 kernels (2 GEMMs, 1 reduce, 1 silu, 1 linear)

Total unfused per decode token

  • ~3 setup kernels
  • ~240 kernels (event loop)
  • ~5 collapse kernels
  • ~248 kernel launches total
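The launch tally can be reproduced from the trace's own approximate per-step counts. The constants below are the estimates stated above, not profiler output.

```python
# Approximate per-token kernel counts for the unfused path (estimates from this trace)
SETUP_KERNELS = 3
SWAP_KERNELS = 20        # _swap_once subtotal
FFN_KERNELS = 5          # 3 GEMMs + 2 elementwise per _ffn_step
K_FFN = 2                # FFN steps per event
N_EVENTS = 8
COLLAPSE_KERNELS = 5

per_event = SWAP_KERNELS + K_FFN * FFN_KERNELS
event_loop = N_EVENTS * per_event
total = SETUP_KERNELS + event_loop + COLLAPSE_KERNELS
print(per_event, event_loop, total)  # 30 240 248
```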

Weight DRAM traffic per decode token

Each FFN step reads w1, w3, w2 once:

  • 3 × 8 MB = 24 MB per FFN step
  • 16 FFN steps (8 events × 2 k_ffn) × 24 MB = 384 MB total weight reads

The weights likely stay hot in L2 (40 MB on A100) across the 512 "rows" within each GEMM call, so actual DRAM traffic per GEMM ≈ 8 MB (weight size). Between successive GEMM calls (w1→w3→w2→w1→...), the previous weight gets evicted, so each weight is re-fetched from DRAM each time.
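A quick check of the weight-traffic arithmetic, assuming bf16 weights and exactly one DRAM read of each weight per GEMM call. Sizes are computed in MiB, which this trace writes as MB.

```python
MiB = 1024 * 1024

D_MODEL, D_FF = 1024, 4096
BYTES_BF16 = 2
N_EVENTS, K_FFN = 8, 2

weight_bytes = D_FF * D_MODEL * BYTES_BF16     # one of w1 / w3 / w2
per_ffn_step = 3 * weight_bytes                # w1 + w3 + w2
ffn_steps = N_EVENTS * K_FFN                   # 16 FFN steps per token
total = ffn_steps * per_ffn_step               # weight bytes per decode token

print(weight_bytes // MiB, per_ffn_step // MiB, ffn_steps, total // MiB)  # 8 24 16 384
```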


FUSED PATH

Entry: SwapFFNModule.forward_run_fused_decodeswap_ffn_fused_decodeswap_ffn_fused_decode_triton_launch_swap_ffn_decode_kernel

Step 0: Setup (Python-side)

Same core_scale/core_shift/core_alpha computation as unfused. Then:

# Allocate workspaces
ffn_workspace: [B * 512, 4096 + 1024] float32    # ~10 MB for B=1
collapse_workspace: [B, 224] float32               # tiny
collapse_workspace.zero_()                          # must be zeroed for atomic_add

Step 1: Single Triton kernel launch

grid = (B * 512,)   # 512 blocks for B=1
_swap_ffn_decode_kernel[grid](...)
# num_warps=4, num_stages=2
# Each block: 4 warps × 32 threads = 128 threads

One kernel launch. 512 blocks get scheduled across 108 SMs on A100. Each block processes one (batch_item, core) pair entirely independently.

Inside each block (one of 512):

1a. Load input + core affine

x = load x_ptr[b, 0:1024]                    # 2 KB (bf16)
core_scale = load core_scale_ptr[c, 0:1024]   # 2 KB
core_shift = load core_shift_ptr[c, 0:1024]   # 2 KB
x = x * core_scale + core_shift               # in registers

1b. Load stream map + decode lanes

stream_u8 = load stream_in[b, c]              # 1 byte
lane0 = stream_u8 & 0x03
lane1 = (stream_u8 >> 2) & 0x03
lane2 = (stream_u8 >> 4) & 0x03
lane3 = (stream_u8 >> 6) & 0x03
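The 2-bit lane layout above (lane0 in the low bits) can be sketched as a decode/pack pair in plain Python. `decode_lanes` and `pack_lanes` are hypothetical stand-ins; the sketch assumes the unfused `decode_stream_lanes` / `pack_stream_lanes` use the same layout as this kernel's decode.

```python
def decode_lanes(stream_u8):
    """Split one packed uint8 stream-map entry into four 2-bit lanes (lane0 = bits 0-1)."""
    return [(stream_u8 >> (2 * i)) & 0x03 for i in range(4)]


def pack_lanes(lanes):
    """Inverse of decode_lanes: pack four 2-bit lane values back into one byte."""
    if len(lanes) != 4 or not all(0 <= v <= 3 for v in lanes):
        raise ValueError("expected four lane values in [0, 3]")
    packed = 0
    for i, v in enumerate(lanes):
        packed |= v << (2 * i)
    return packed


# Example: an identity mapping (logical bank i -> physical bank i)
identity = pack_lanes([0, 1, 2, 3])
print(identity, decode_lanes(identity))  # 228 [0, 1, 2, 3]
```

Each lane holds a physical bank index in 0-3, so four lanes fit exactly in one byte, which is why the stream map can be stored as uint8.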

1c. Load all 4 RAM banks

ram0 = load ram[b, c, 0, 0:1024]   # 2 KB
ram1 = load ram[b, c, 1, 0:1024]   # 2 KB
ram2 = load ram[b, c, 2, 0:1024]   # 2 KB
ram3 = load ram[b, c, 3, 0:1024]   # 2 KB

All of x, ram0-3, core_scale, and core_shift live in registers (float32). That's 7 × 1024 = 7168 floats = 28 KB of register state per block. With 128 threads, that's 224 bytes, i.e. 56 32-bit registers, per thread. Loop variables and intermediates such as the [128, 128] weight tiles and the gate/up accumulators push this to 255 registers/thread (the ncu-measured value), the per-thread hardware maximum.
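A rough theoretical-occupancy calculation for this launch configuration. It assumes the published A100 (compute capability 8.0) per-SM limits of 65,536 registers, 64 resident warps, and 32 resident blocks, plus a register allocation granularity of 256 registers per warp, and it ignores shared-memory limits. Theoretical occupancy is a different metric from the SM utilization that ncu reports below, so the two numbers are not expected to match.

```python
import math

# Assumed A100 (sm_80) per-SM limits
REGS_PER_SM = 65536
MAX_WARPS_PER_SM = 64
MAX_BLOCKS_PER_SM = 32
REG_ALLOC_UNIT = 256      # registers are allocated per warp in units of 256
WARP_SIZE = 32


def blocks_per_sm(regs_per_thread, warps_per_block):
    """Resident blocks per SM, considering only register, warp, and block-count limits."""
    regs_per_warp = math.ceil(regs_per_thread * WARP_SIZE / REG_ALLOC_UNIT) * REG_ALLOC_UNIT
    by_regs = REGS_PER_SM // (regs_per_warp * warps_per_block)
    by_warps = MAX_WARPS_PER_SM // warps_per_block
    return min(by_regs, by_warps, MAX_BLOCKS_PER_SM)


n_blocks = blocks_per_sm(255, 4)                    # num_warps=4, 255 regs/thread
occupancy = n_blocks * 4 / MAX_WARPS_PER_SM         # resident warps / max warps
waves = math.ceil(512 / (108 * n_blocks))           # 512 blocks over 108 SMs
print(n_blocks, occupancy, waves)  # 2 0.125 3
```

Under these assumptions, only 2 blocks (8 warps) fit per SM, so the 512 blocks run in 3 waves at 12.5% theoretical occupancy.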

1d. Event loop (8 iterations)

Each iteration:

Router:

# RMS norm
x_norm = x * rsqrt(sum(x*x) / 1024 + eps)
x_norm *= load(router_norm[0:1024])            # 2 KB, shared across all blocks

# Router matmul — SEQUENTIAL dot products
z[0:64] = zeros
for q in range(64):                             # SERIAL loop!
    w = load router_w[q, 0:1024]               # 2 KB per iteration
    z[q] = sum(w * x_norm)                     # reduce across D=1024

This is not a GEMM. It's 64 sequential vector dot products. Each loads a full row of router_w (2 KB). Total: 64 × 2 KB = 128 KB of weight reads per block per event.

Swap decision:

w_swap = load router_swap_w[0:64]              # 128 bytes
swap_logit = sum(z * w_swap) + load(router_swap_b)
gate = 1 / (1 + exp(-swap_logit))
gate_hard = gate > 0.5

Bank selection:

# 4 dot products of z[64] with bank_w columns
for bank in 0..3:
    bank_score[bank] = sum(z * load(bank_w[:, bank])) + load(bank_b[bank])
selected_logical = argmax(bank_score)
selected_physical = lane_map[selected_logical]
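The router, swap decision, and bank selection for a single (batch item, core) pair can be sketched in plain Python at toy sizes. The function names, argument layout, and eps=1e-6 are assumptions for illustration; the list comprehension over `router_w` rows plays the role of the serial `q` loop.

```python
import math


def sigmoid(v):
    """Numerically stable logistic function."""
    if v >= 0:
        return 1.0 / (1.0 + math.exp(-v))
    e = math.exp(v)
    return e / (1.0 + e)


def rms_norm(x, weight, eps=1e-6):
    """RMS-normalize x, then apply an elementwise weight."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def route_core(x, router_norm, router_w, swap_w, swap_b, bank_w, bank_b, lane_map):
    """router_w: [M][D]; bank_w: [M][4]; lane_map: [4] logical -> physical bank."""
    x_norm = rms_norm(x, router_norm)
    # One dot product per router output, one after another (the serial q loop)
    z = [sum(w * v for w, v in zip(row, x_norm)) for row in router_w]
    swap_logit = sum(zi * wi for zi, wi in zip(z, swap_w)) + swap_b
    gate_hard = sigmoid(swap_logit) > 0.5
    # Four dot products of z with the columns of bank_w
    bank_score = [sum(z[m] * bank_w[m][k] for m in range(len(z))) + bank_b[k] for k in range(4)]
    selected_logical = max(range(4), key=lambda k: bank_score[k])   # argmax
    return gate_hard, selected_logical, lane_map[selected_logical]


# Toy example: D=2, M=2
args = dict(
    router_norm=[1.0, 1.0],
    router_w=[[1.0, 0.0], [0.0, 1.0]],
    swap_w=[1.0, 1.0],
    bank_w=[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]],
    bank_b=[0.0, 0.0, 0.0, 0.0],
    lane_map=[3, 2, 1, 0],
)
print(route_core([3.0, 4.0], swap_b=0.0, **args))   # (True, 1, 2)
print(route_core([3.0, 4.0], swap_b=-10.0, **args))  # (False, 1, 2)
```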

RAM swap:

selected_bank = ram[selected_physical]           # one of ram0-3, already in registers: no load
x_before = x
x = where(gate_hard, selected_bank, x)
ram[selected_physical] = where(gate_hard, x_before, ram[selected_physical])
# Update lane mappings

All in registers, no DRAM access.
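Per core, the swap exchanges x with the selected physical bank when gate_hard is set. This plain-Python sketch compares the fused in-register form with the unfused gather/where/stack form. The function names are hypothetical, the lane-map update is omitted, and the per-bank mask definition `mask_i = gate_hard and selected_physical == i` is an assumption, since the trace does not spell out how mask0 to mask3 are built.

```python
def swap_fused(x, banks, selected_physical, gate_hard):
    """Fused style: all four banks are already local; touch only the selected one."""
    banks = list(banks)                        # shallow copy; rows are replaced, not mutated
    if not gate_hard:
        return x, banks
    x_new = banks[selected_physical]           # selected_bank, no extra load
    banks[selected_physical] = x               # x_before goes back into RAM
    return x_new, banks


def swap_unfused(x, banks, selected_physical, gate_hard):
    """Unfused style: gather, where() on x, then rebuild every bank with a mask."""
    x_prev = x
    selected_bank = banks[selected_physical]               # gather
    x_new = selected_bank if gate_hard else x              # where(gate_hard, selected_bank, x)
    new_banks = []
    for i, bank in enumerate(banks):
        mask_i = gate_hard and selected_physical == i      # assumed mask definition
        new_banks.append(x_prev if mask_i else bank)       # where(mask_i, x_prev, bank)
    return x_new, new_banks                                # stack(...) in the tensor version


x = [1.0, 1.0]
banks = [[0.0, 0.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]]
print(swap_fused(x, banks, 2, True))
# ([3.0, 3.0], [[0.0, 0.0], [2.0, 2.0], [1.0, 1.0], [4.0, 4.0]])
```

The fused form does a single conditional exchange on data it already holds, while the unfused form has to rewrite all four banks, which is where its ~5 extra kernels come from.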

FFN — THE BOTTLENECK (K_FFN=2 times):

# RMS norm
x_norm = x * rsqrt(sum(x*x)/1024 + eps) * load(ffn_norm[0:1024])
store x_norm to ffn_workspace[pid, 4096:5120]   # 4 KB write

# Gate (w1) + Up (w3) matmul — TILED
for ff_base in range(0, 4096, 128):             # 32 tiles
    gate_acc[0:128] = zeros
    up_acc[0:128] = zeros
    for d_base in range(0, 1024, 128):          # 8 tiles
        x_seg = load workspace[pid, 4096+d_base : 4096+d_base+128]
        w1_tile = load w1[ff_base:ff_base+128, d_base:d_base+128]  # 32 KB (bf16)
        w3_tile = load w3[ff_base:ff_base+128, d_base:d_base+128]  # 32 KB
        gate_acc += w1_tile @ x_seg              # [128] += [128,128] × [128]
        up_acc += w3_tile @ x_seg
    silu_gate = gate_acc / (1 + exp(-gate_acc))
    inter = silu_gate * up_acc
    store inter to workspace[pid, ff_base:ff_base+128]

# Down (w2) matmul — TILED
for d_out_base in range(0, 1024, 128):          # 8 tiles
    out[0:128] = zeros
    for ff_base in range(0, 4096, 128):         # 32 tiles
        inter = load workspace[pid, ff_base:ff_base+128]
        w2_tile = load w2[d_out_base:d_out_base+128, ff_base:ff_base+128]  # 32 KB
        out += w2_tile @ inter
    store out to workspace[pid, 4096+d_out_base : ...]

x = load workspace[pid, 4096:5120]              # reload result
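The FFN loop order can be checked in plain Python at toy sizes (TILE=4 instead of 128, D=8, d_ff=16). This sketch, with hypothetical helper names, reproduces only the arithmetic of the tiled loops, not Triton's vectorization, bf16 loads, or the workspace store/reload, and compares the result with an untiled SwiGLU reference.

```python
import math
import random


def silu(v):
    """SiLU(v) = v * sigmoid(v), arranged to avoid exp overflow for large |v|."""
    if v >= 0:
        return v / (1.0 + math.exp(-v))
    e = math.exp(v)
    return v * e / (1.0 + e)


def matvec(w, x):
    """w: [rows][cols], x: [cols] -> [rows]"""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def ffn_reference(x_norm, w1, w3, w2):
    """Untiled SwiGLU: w2 @ (silu(w1 @ x) * (w3 @ x))."""
    inter = [silu(g) * u for g, u in zip(matvec(w1, x_norm), matvec(w3, x_norm))]
    return matvec(w2, inter)


def ffn_tiled(x_norm, w1, w3, w2, tile):
    """Same math in the fused kernel's loop order; w1/w3 are [ff][d], w2 is [d][ff]."""
    d, ff = len(x_norm), len(w1)
    assert d % tile == 0 and ff % tile == 0
    inter = [0.0] * ff
    for ff_base in range(0, ff, tile):              # 32 tiles in the real kernel
        gate_acc, up_acc = [0.0] * tile, [0.0] * tile
        for d_base in range(0, d, tile):            # 8 tiles in the real kernel
            x_seg = x_norm[d_base:d_base + tile]
            for i in range(tile):
                w1_row = w1[ff_base + i][d_base:d_base + tile]
                w3_row = w3[ff_base + i][d_base:d_base + tile]
                gate_acc[i] += sum(a * b for a, b in zip(w1_row, x_seg))
                up_acc[i] += sum(a * b for a, b in zip(w3_row, x_seg))
        for i in range(tile):
            inter[ff_base + i] = silu(gate_acc[i]) * up_acc[i]
    out = [0.0] * d
    for d_out_base in range(0, d, tile):            # down projection: output tiles
        acc = [0.0] * tile
        for ff_base in range(0, ff, tile):          # reduction over d_ff tiles
            inter_seg = inter[ff_base:ff_base + tile]
            for i in range(tile):
                w2_row = w2[d_out_base + i][ff_base:ff_base + tile]
                acc[i] += sum(a * b for a, b in zip(w2_row, inter_seg))
        out[d_out_base:d_out_base + tile] = acc
    return out


rng = random.Random(0)
D, FF, TILE = 8, 16, 4


def rand_matrix(rows, cols):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


w1, w3, w2 = rand_matrix(FF, D), rand_matrix(FF, D), rand_matrix(D, FF)
x_norm = [rng.uniform(-1.0, 1.0) for _ in range(D)]
ref = ffn_reference(x_norm, w1, w3, w2)
tiled = ffn_tiled(x_norm, w1, w3, w2, TILE)
```

Note that every block walks the entire w1, w3, and w2 in this order, once per FFN step, regardless of how many other blocks are doing the same.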

Per FFN step, per block, weight data loaded from DRAM/L2:

  • w1: 32 tiles × 8 subtiles × 32 KB = 8 MB
  • w3: same = 8 MB
  • w2: 8 tiles × 32 subtiles × 32 KB = 8 MB
  • Total: 24 MB per block per FFN step

Per event: 2 × 24 MB = 48 MB per block
Per kernel (8 events): 8 × 48 MB = 384 MB per block

With 512 blocks all loading the same weights: best case (perfect L2 sharing) would be 384 MB total. Worst case (no sharing) would be 512 × 384 MB = 192 GB.

ncu measured 6.5 GB of DRAM reads, which means L2 helps substantially but is far from perfect. With 512 blocks competing for 40 MB of L2, weight tiles get evicted frequently.
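The best case, the worst case, and the measured value can be put side by side. This sketch treats the trace's MB/GB as binary units (MiB/GiB), consistent with the 8 MB weight size above; if ncu's 6.5 GB is decimal, the ratio to the best case drops to about 16×.

```python
MiB = 1024 ** 2
GiB = 1024 ** 3

per_block = 8 * 2 * 24 * MiB       # 8 events x K_FFN=2 x 24 MiB per FFN step
best_case = per_block              # perfect L2 sharing: each weight byte read from DRAM once
worst_case = 512 * per_block       # no L2 sharing across the 512 blocks
measured = 6.5 * GiB               # ncu DRAM reads (assumed binary GB)

print(best_case / MiB, worst_case / GiB)       # 384.0 192.0
print(round(measured / best_case, 1))          # 17.3  (x best case)
print(round(measured / worst_case * 100, 1))   # 3.4   (% of worst case)
```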

1e. Write RAM outputs

store ram0 → ram_out[b, c, 0, 0:1024]
store ram1 → ram_out[b, c, 1, 0:1024]
store ram2 → ram_out[b, c, 2, 0:1024]
store ram3 → ram_out[b, c, 3, 0:1024]
store new_map → stream_out[b, c]

1f. Collapse accumulation

for r in range(0, 224, 32):                     # 7 r-blocks
    accum[0:32] = zeros
    for k in range(8):                          # collapse_k
        alpha = load core_alpha[c, k]           # scalar
        for d_base in range(0, 1024, 128):      # 8 d-blocks
            x_seg = load workspace[pid, 4096+d_base:...]
            basis_tile = load collapse_basis[k, r:r+32, d_base:d_base+128]  # 8 KB
            accum += alpha * (basis_tile @ x_seg)
    atomic_add(collapse_workspace[b, r:r+32], accum)

Per block: reads collapse_basis = 7 × 8 × 8 × 8 KB = 3.5 MB.
Across 512 blocks that is 512 × 3.5 MB ≈ 1.8 GB of reads, all of the same 3.5 MB tensor, so they can be shared via L2.
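Because the collapse is linear in core_x, pooling over cores first (unfused) and letting each core add its own alpha-weighted projection into a shared buffer (fused) give the same pre-SiLU z up to floating-point rounding. A toy-size sketch with hypothetical function names; the sequential `+=` stands in for `atomic_add` and does not model the non-deterministic ordering of real atomics.

```python
import math
import random


def collapse_reference(core_x, core_alpha, basis):
    """Unfused order. core_x: [C][D], core_alpha: [C][K], basis: [K][R][D] -> z: [R]."""
    C, D = len(core_x), len(core_x[0])
    K, R = len(basis), len(basis[0])
    # pooled[k][d] = sum_c alpha[c][k] * core_x[c][d]
    pooled = [[sum(core_alpha[c][k] * core_x[c][d] for c in range(C)) for d in range(D)]
              for k in range(K)]
    # z[r] = sum_k sum_d pooled[k][d] * basis[k][r][d]
    return [sum(sum(pooled[k][d] * basis[k][r][d] for d in range(D)) for k in range(K))
            for r in range(R)]


def collapse_per_core(core_x, core_alpha, basis):
    """Fused order: each core ("block") accumulates its contribution into shared z."""
    K, R = len(basis), len(basis[0])
    z = [0.0] * R                                   # the zeroed collapse_workspace
    for c, x in enumerate(core_x):
        for r in range(R):
            accum = sum(core_alpha[c][k] * sum(b * v for b, v in zip(basis[k][r], x))
                        for k in range(K))
            z[r] += accum                           # stands in for atomic_add
    return z


rng = random.Random(0)
C, K, R, D = 3, 2, 2, 4
core_x = [[rng.uniform(-1, 1) for _ in range(D)] for _ in range(C)]
core_alpha = [[rng.uniform(0, 1) for _ in range(K)] for _ in range(C)]
basis = [[[rng.uniform(-1, 1) for _ in range(D)] for _ in range(R)] for _ in range(K)]
z_ref = collapse_reference(core_x, core_alpha, basis)
z_fused = collapse_per_core(core_x, core_alpha, basis)
```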

Step 2: Post-kernel Python

# Back in Python after kernel completes
if collapse_silu:
    collapse_workspace = F.silu(collapse_workspace)     # [B, 224]
delta = F.linear(collapse_workspace, collapse_decoder)  # [B, 224] × [1024, 224]^T → [B, 1024]

GPU ops: 2 more kernels (silu + gemm), trivially small.

Total fused per decode token

  • 1 Triton kernel launch (75.1 ms)
  • 2 tiny follow-up kernels (collapse_silu + collapse_decoder matmul)

Comparison

|  | Unfused | Fused |
|---|---|---|
| Kernel launches | ~248 | 3 (1 Triton kernel + 2 follow-ups; setup ops shared with the unfused path not counted) |
| Weight loading strategy | cuBLAS GEMM: load w1 once, apply to all 512 rows | Each of the 512 blocks loads w1 independently, in tiles |
| Weight DRAM reads | ~384 MB (estimated: 16 FFN steps × 24 MB) | ~6.5 GB (measured by ncu) |
| FFN compute | cuBLAS (highly optimized tiling, tensor cores) | Triton tiled loops, 128 threads per block, tiles processed sequentially |
| Register pressure | Low (managed internally by cuBLAS) | 255 regs/thread (hardware limit), which kills occupancy |
| SM utilization | High (cuBLAS saturates the SMs) | 35.6% (low occupancy from register pressure) |
| Router compute | 1 batched GEMM: [B*512, 1024] × [64, 1024]^T | 64 sequential dot products per block (serial q loop) |
| Throughput (B=1) | 49.3 tok/s | 13.7 tok/s |

Why the fused kernel is slower

The fundamental issue: cuBLAS computes [512, 1024] × [4096, 1024]^T as a single operation, loading each weight tile once and broadcasting it across all 512 input rows (cores) using shared memory within each SM. The fused Triton kernel gives each core its own block, and each block redundantly loads the entire weight matrix. The weights are shared (same w1/w3/w2 for all cores), but the kernel structure doesn't exploit that — 512 blocks compete for 40 MB of L2 cache, causing massive DRAM re-reads.
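The difference can be captured by a simplified model: weight bytes requested per GEMM-equivalent grow with the number of blocks that each stream the full weight for their own group of rows. This sketch ignores L2 hits entirely, and the 128-row tile height is a hypothetical GEMM-style configuration, not cuBLAS's actual choice; `weight_bytes_requested` is an illustrative name.

```python
import math

MiB = 1024 ** 2
WEIGHT_BYTES = 4096 * 1024 * 2     # one bf16 FFN weight matrix (w1, w3, or w2)


def weight_bytes_requested(n_rows, rows_per_block, weight_bytes=WEIGHT_BYTES):
    """Bytes requested if each block streams the whole weight once for its rows."""
    n_blocks = math.ceil(n_rows / rows_per_block)
    return n_blocks * weight_bytes


per_core = weight_bytes_requested(512, 1)     # fused kernel: one core per block
batched = weight_bytes_requested(512, 128)    # GEMM-style: 128 rows share each weight tile
print(per_core // MiB, batched // MiB, per_core // batched)  # 4096 32 128
```

In this model, giving each core its own block multiplies weight requests by the row-tile height, which is the redundancy that L2 then has to absorb.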

Measured ncu metrics (commit d45767f)

Fused Triton _swap_ffn_decode_kernel:

  • Duration: 75.1 ms per launch
  • DRAM reads: 6.5 GB
  • DRAM writes: 262 MB
  • Registers/thread: 255
  • SM utilization: 35.6%
  • DRAM throughput: 4.4%
  • Instructions executed: 12.95 billion

Unfused cuBLAS gemm (decode hot path):

  • Duration: 12.3 µs per launch
  • DRAM reads: 8.5 MB per launch