Minimal repro: fp8 SS MMA with NoSwizzle vs SW32 vs SW64 vs SW128 on SM100 (B200). All four swizzle modes produce identical MMA throughput.
// Minimal repro: fp8 SS MMA with NoSwizzle vs SW32 vs SW64 vs SW128.
// All four produce identical MMA throughput (129 cycles/iter).
// Only depends on CUTLASS (no FlashMLA, no mla_decode_v21).
//
// Build & run (requires sm_100a GPU + CUTLASS source):
//   nvcc -std=c++20 -O2 --generate-code=arch=compute_100a,code=[sm_100a] \
//     -I third-party/cutlass/include \
//     -o scripts/swizzle_mma_repro scripts/swizzle_mma_repro.cu \
//   && ./scripts/swizzle_mma_repro
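//
// Note: compute_100a / sm_100a needs a recent toolchain. To my knowledge this
// means CUDA 12.8 or newer plus a CUTLASS checkout with SM100 (tcgen05/UMMA)
// support; treat the exact versions as an assumption, not a tested matrix.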
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cstring>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cutlass/arch/barrier.h>
#include <cute/tensor.hpp>
#include <cute/arch/tmem_allocator_sm100.hpp>
#include <cute/atom/mma_traits_sm100.hpp>
using namespace cute;
using e4m3 = cutlass::float_e4m3_t;
using transac_bar_t = cutlass::arch::ClusterTransactionBarrier;
constexpr int M = 64, N = 64, K = 128;
constexpr int BUF = M * K;
constexpr int BENCH_ITERS = 100;
// =====================================================================
// Inlined SS MMA atom (from fp8_mma_atom.h, no external dependency)
// =====================================================================
namespace cute {
template <class a_type, class b_type, class c_type,
          int M_, int N_, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg = UMMA::ScaleIn::One,
          UMMA::ScaleIn b_neg = UMMA::ScaleIn::One>
struct SM100_MMA_F8F6F4_WS_SS_REPRO {
  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[1];

  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a, uint64_t const& desc_b,
      uint32_t const& tmem_c, uint32_t const& scaleC, uint64_t const& idescE) {
    asm volatile(
        "{\n\t"
        ".reg .pred p;\n\t"
        "setp.ne.b32 p, %4, 0;\n\t"
        "tcgen05.mma.ws.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, p, 0;\n\t"
        "}\n"
        :
        : "r"(tmem_c), "l"(desc_a), "l"(desc_b),
          "r"(uint32_t(idescE >> 32)), "r"(scaleC));
  }
};
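// Operand sketch for the tcgen05.mma.ws instruction above (paraphrasing the
// PTX ISA description, not normative):
//   %0     TMEM address of the D/C accumulator tile
//   %1, %2 SMEM matrix descriptors for A and B
//   %3     upper 32 bits of the runtime instruction descriptor (types, MxN shape, majorness)
//   p      enable-input-d: true means D += A*B, false means D = A*B
//   0      trailing zero-column mask descriptor of the .ws form (unused here)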
template <class a_type, class b_type, class c_type,
          int M_, int N_, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
struct MMA_Traits<SM100_MMA_F8F6F4_WS_SS_REPRO<a_type, b_type, c_type,
                                               M_, N_, a_major, b_major, a_neg, b_neg>> {
  using ValTypeD = c_type;
  using ValTypeA = a_type;
  using ValTypeB = b_type;
  using ValTypeC = c_type;
  using FrgTypeA = UMMA::smem_desc<a_major>;
  using FrgTypeB = UMMA::smem_desc<b_major>;
  using FrgTypeC = UMMA::tmem_frg_ws_1sm<c_type>;

  static constexpr int K_ = 32;
  using Shape_MNK = Shape<Int<M_>, Int<N_>, Int<K_>>;
  using ThrID = Layout<_1>;
  using ALayout = Layout<Shape<_1, Shape<Int<M_>, Int<K_>>>,
                         Stride<_0, Stride<_1, Int<M_>>>>;
  using BLayout = Layout<Shape<_1, Shape<Int<N_>, Int<K_>>>,
                         Stride<_0, Stride<_1, Int<N_>>>>;
  using CLayout = Layout<Shape<_1, Shape<Int<M_>, Int<N_>>>,
                         Stride<_0, Stride<_1, Int<M_>>>>;

  UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;
  UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
      a_type, b_type, c_type, M_, N_, a_major, b_major, a_neg, b_neg>();

  template <class TD, class DL, class TA, class AL, class TB, class BL, class TC, class CL>
  CUTE_HOST_DEVICE constexpr friend void
  mma_unpack(MMA_Traits const& traits,
             Tensor<TD, DL>& D, Tensor<TA, AL> const& A,
             Tensor<TB, BL> const& B, Tensor<TC, CL> const& C) {
    uint64_t desc_a = A[0], desc_b = B[0];
    uint32_t tmem_c = raw_pointer_cast(D.data());
    uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);
    SM100_MMA_F8F6F4_WS_SS_REPRO<a_type, b_type, c_type, M_, N_, a_major, b_major,
                                 a_neg, b_neg>::fma(desc_a, desc_b, tmem_c,
                                                    uint32_t(traits.accumulate_), idesc);
  }
};
} // namespace cute
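// Reading of the traits above (my interpretation of the CuTe conventions, not
// from the source): ThrID = Layout<_1> because a single elected thread issues
// the MMA for the whole tile, so the leading stride-0 mode in
// ALayout/BLayout/CLayout makes that one "thread" own the full MxK / NxK / MxN
// extents. K_ = 32 because kind::f8f6f4 consumes 32 one-byte e4m3 elements
// (32 B) of K per issue.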
// =====================================================================
// Inlined helpers (from kerutils, pure PTX)
// =====================================================================
__device__ __forceinline__ void tcgen05_fence_before() {
  asm volatile("tcgen05.fence::before_thread_sync;");
}
__device__ __forceinline__ void tcgen05_fence_after() {
  asm volatile("tcgen05.fence::after_thread_sync;");
}
__device__ __forceinline__ void umma_commit(transac_bar_t& bar) {
  uint32_t p = cute::cast_smem_ptr_to_uint(&bar);
  asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];" :: "r"(p));
}
template <class TiledMMA, class TA, class TB, class TC>
__device__ void do_ss_mma(TiledMMA& mma, TA sA, TB sB, TC tC, bool clear) {
  mma.accumulate_ = clear ? UMMA::ScaleOut::Zero : UMMA::ScaleOut::One;
  auto thr = mma.get_slice(_0{});
  auto fA = thr.partition_fragment_A(sA);
  auto fB = thr.partition_fragment_B(sB);
  CUTE_UNROLL
  for (int k = 0; k < size<2>(fA); ++k) {
    cute::gemm(mma, fA(_, _, k), fB(_, _, k), tC);
    mma.accumulate_ = UMMA::ScaleOut::One;
  }
}
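// With K = 128 and the atom's K_ = 32, size<2>(fA) == 4, so each do_ss_mma call
// issues four tcgen05 MMAs; the ~129 cycles/iter quoted in the header therefore
// works out to roughly 32 cycles per issued MMA instruction.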
// =====================================================================
// Types
// =====================================================================
using TiledMMA_SS = decltype(make_tiled_mma(
    SM100_MMA_F8F6F4_WS_SS_REPRO<e4m3, e4m3, float, M, N, UMMA::Major::K, UMMA::Major::K>{}));
using Layout_SW32 = decltype(tile_to_shape(UMMA::Layout_K_SW32_Atom<e4m3>{},
                                           Shape<Int<M>, Int<K>>{}, Step<_1, _2>{}));
using Layout_SW64 = decltype(tile_to_shape(UMMA::Layout_K_SW64_Atom<e4m3>{},
                                           Shape<Int<M>, Int<K>>{}, Step<_1, _2>{}));
using Layout_SW128 = decltype(tile_to_shape(UMMA::Layout_K_SW128_Atom<e4m3>{},
                                            Shape<Int<M>, Int<K>>{}, Step<_1, _2>{}));
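// The SWnn atoms are CUTLASS's canonical K-major swizzled SMEM layouts; the
// number is the swizzle span in bytes (a 32-, 64-, or 128-byte swizzle), which
// in turn selects LayoutType::SWIZZLE_32B/_64B/_128B in the SMEM descriptors
// that CuTe generates for the MMA.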
struct SharedMem {
  array_aligned<e4m3, BUF> a;
  array_aligned<e4m3, BUF> b;
  array_aligned<uint32_t, 1> tmem_addr;
  transac_bar_t bar;
};
// =====================================================================
// Swizzled MMA kernel (CuTE descriptors)
// =====================================================================
template <typename Layout, int ITERS>
__global__ void __launch_bounds__(128, 1)
mma_swizzle(float* out, const e4m3* a_vals, const e4m3* b_vals, int64_t* cycles) {
  extern __shared__ char sbuf[];
  auto& sm = *reinterpret_cast<SharedMem*>(sbuf);
  int warp = cutlass::canonical_warp_idx_sync();
  int lane = threadIdx.x % 32;
  int tid = threadIdx.x % 128;
  if (warp == 0) {
    if (elect_one_sync()) { sm.bar.init(1); cutlass::arch::fence_barrier_init(); }
    cute::TMEM::Allocator1Sm().allocate(128, sm.tmem_addr.data());
    cute::TMEM::Allocator1Sm().release_allocation_lock();
  }
  __syncthreads();
  {
    Tensor sA = make_tensor(make_smem_ptr(sm.a.data()), Layout{});
    Tensor sB = make_tensor(make_smem_ptr(sm.b.data()), Layout{});
    for (int i = tid; i < BUF; i += 128) {
      int r = i / K, c = i % K;
      sA(r, c) = a_vals[r * K + c];
    }
    for (int i = tid; i < BUF; i += 128) {
      int r = i / K, c = i % K;
      sB(r, c) = b_vals[r * K + c];
    }
    cutlass::arch::fence_view_async_shared();
  }
  __syncthreads();
  int64_t t0 = 0, t1 = 0;
  if (warp == 0 && elect_one_sync()) {
    TiledMMA_SS mma{};
    Tensor sA = make_tensor(make_smem_ptr(sm.a.data()), Layout{});
    Tensor sB = make_tensor(make_smem_ptr(sm.b.data()), Layout{});
    Tensor tC = partition_fragment_C(mma, Shape<Int<M>, Int<N>>{});
    tC.data().get() = 0;  // point the accumulator fragment at TMEM address 0, the base of our allocation
    t0 = clock64();
    do_ss_mma(mma, sA, sB, tC, true);
    CUTE_UNROLL
    for (int i = 1; i < ITERS; i++)
      do_ss_mma(mma, sA, sB, tC, false);
    umma_commit(sm.bar);
  }
  if (warp == 0) sm.bar.wait(0);
  if (warp == 0 && elect_one_sync()) { t1 = clock64(); if (cycles) *cycles = t1 - t0; }
  __syncthreads();
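  // Epilogue sketch (my reading of the fragment layout, not from the source):
  // SM100_TMEM_LOAD_32dp32b32x gives each of the 128 threads 32 fp32 values from
  // TMEM datapath lane dp = warp*32 + lane. With the tmem_frg_ws_1sm accumulator
  // layout, lanes 0..63 carry N-columns 0..31 and lanes 64..127 carry N-columns
  // 32..63 of the same 64 rows, hence the noff/m split below.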
  {
    tcgen05_fence_after();
    uint32_t cv[N / 2];
    cute::SM100_TMEM_LOAD_32dp32b32x::copy(0, cv[0], cv[1], cv[2], cv[3], cv[4], cv[5],
        cv[6], cv[7], cv[8], cv[9], cv[10], cv[11], cv[12], cv[13], cv[14], cv[15],
        cv[16], cv[17], cv[18], cv[19], cv[20], cv[21], cv[22], cv[23], cv[24], cv[25],
        cv[26], cv[27], cv[28], cv[29], cv[30], cv[31]);
    cutlass::arch::fence_view_async_tmem_load();
    tcgen05_fence_before();
    int dp = warp * 32 + lane;
    int noff = (dp < 64) ? 0 : (N / 2);
    int m = (dp < 64) ? dp : dp - 64;
    if (m < M)
      for (int c = 0; c < N / 2; c++)
        out[m * N + c + noff] = __uint_as_float(cv[c]);
  }
  if (warp == 0) cute::TMEM::Allocator1Sm().free(0, 128);
}
// =====================================================================
// No-swizzle MMA kernel (manual descriptors, interleaved SMEM format)
// =====================================================================
__device__ uint64_t make_desc_noswz(uint32_t addr, int lbo, int sbo) {
  UMMA::SmemDescriptor d;
  d.desc_ = 0;
  d.start_address_ = addr >> 4;
  d.leading_byte_offset_ = lbo;
  d.stride_byte_offset_ = sbo;
  d.version_ = 1;
  d.lbo_mode_ = 0;
  d.layout_type_ = uint8_t(UMMA::LayoutType::SWIZZLE_NONE);
  return d.desc_;
}
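// In the SMEM matrix descriptor, the start address, leading byte offset and
// stride byte offset are all encoded in 16-byte units (note the addr >> 4).
// The values used below, LBO = 8 and SBO = (8*K)/16 = 64, therefore mean
// 128 B (one 8-row x 16-B core matrix) and 1024 B (the distance between
// consecutive 8-row groups of the interleaved layout).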
__device__ int interleaved_off(int row, int col) {
  return (row / 8) * (8 * K) + (col / 16) * 128 + (row % 8) * 16 + (col % 16);
}
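// Worked example: (row 9, col 20) maps to
//   (9/8)*1024 + (20/16)*128 + (9%8)*16 + (20%16) = 1024 + 128 + 16 + 4 = 1172,
// i.e. rows are grouped in 8s and columns in 16s, and each 8x16-byte group is a
// contiguous 128 B core matrix.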
template <int ITERS>
__global__ void __launch_bounds__(128, 1)
mma_noswizzle(float* out, const e4m3* a_vals, const e4m3* b_vals, int64_t* cycles) {
  extern __shared__ char sbuf[];
  auto& sm = *reinterpret_cast<SharedMem*>(sbuf);
  int warp = cutlass::canonical_warp_idx_sync();
  int lane = threadIdx.x % 32;
  int tid = threadIdx.x % 128;
  if (warp == 0) {
    if (elect_one_sync()) { sm.bar.init(1); cutlass::arch::fence_barrier_init(); }
    cute::TMEM::Allocator1Sm().allocate(128, sm.tmem_addr.data());
    cute::TMEM::Allocator1Sm().release_allocation_lock();
  }
  __syncthreads();
  {
    auto* ar = reinterpret_cast<uint8_t*>(sm.a.data());
    auto* br = reinterpret_cast<uint8_t*>(sm.b.data());
    for (int i = tid; i < BUF; i += 128) {
      int r = i / K, c = i % K;
      ar[interleaved_off(r, c)] = reinterpret_cast<const uint8_t*>(a_vals)[r * K + c];
    }
    for (int i = tid; i < BUF; i += 128) {
      int r = i / K, c = i % K;
      br[interleaved_off(r, c)] = reinterpret_cast<const uint8_t*>(b_vals)[r * K + c];
    }
    cutlass::arch::fence_view_async_shared();
  }
  __syncthreads();
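  // Manual descriptor setup: with the interleaved layout, one K = 32 slice spans
  // two adjacent 16-byte column groups (256 B), hence the sa/sb + k*256 stepping
  // below. acc = 0 on the very first issue clears the accumulator
  // (enable-input-d false); all later issues accumulate.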
  int64_t t0 = 0, t1 = 0;
  if (warp == 0 && elect_one_sync()) {
    uint32_t sa = static_cast<uint32_t>(__cvta_generic_to_shared(sm.a.data()));
    uint32_t sb = static_cast<uint32_t>(__cvta_generic_to_shared(sm.b.data()));
    constexpr int LBO = 8, SBO = (8 * K) / 16;
    UMMA::InstrDescriptor id = UMMA::make_instr_desc<
        e4m3, e4m3, float, M, N, UMMA::Major::K, UMMA::Major::K>();
    uint32_t idhi = uint32_t(UMMA::make_runtime_instr_desc<>(id) >> 32);
    t0 = clock64();
    CUTE_UNROLL
    for (int it = 0; it < ITERS; it++) {
      uint32_t acc = (it == 0) ? 0 : 1;
      CUTE_UNROLL
      for (int k = 0; k < K / 32; k++) {
        uint64_t da = make_desc_noswz(sa + k * 256, LBO, SBO);
        uint64_t db = make_desc_noswz(sb + k * 256, LBO, SBO);
        asm volatile(
            "{\n\t"
            ".reg .pred p;\n\t"
            "setp.ne.b32 p, %4, 0;\n\t"
            "tcgen05.mma.ws.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, p, 0;\n\t"
            "}\n"
            :: "r"(0u), "l"(da), "l"(db), "r"(idhi), "r"(acc));
        acc = 1;
      }
    }
    umma_commit(sm.bar);
  }
  if (warp == 0) sm.bar.wait(0);
  if (warp == 0 && elect_one_sync()) { t1 = clock64(); if (cycles) *cycles = t1 - t0; }
  __syncthreads();
  {
    tcgen05_fence_after();
    uint32_t cv[N / 2];
    cute::SM100_TMEM_LOAD_32dp32b32x::copy(0, cv[0], cv[1], cv[2], cv[3], cv[4], cv[5],
        cv[6], cv[7], cv[8], cv[9], cv[10], cv[11], cv[12], cv[13], cv[14], cv[15],
        cv[16], cv[17], cv[18], cv[19], cv[20], cv[21], cv[22], cv[23], cv[24], cv[25],
        cv[26], cv[27], cv[28], cv[29], cv[30], cv[31]);
    cutlass::arch::fence_view_async_tmem_load();
    tcgen05_fence_before();
    int dp = warp * 32 + lane;
    int noff = (dp < 64) ? 0 : (N / 2);
    int m_ = (dp < 64) ? dp : dp - 64;
    if (m_ < M)
      for (int c = 0; c < N / 2; c++)
        out[m_ * N + c + noff] = __uint_as_float(cv[c]);
  }
  if (warp == 0) cute::TMEM::Allocator1Sm().free(0, 128);
}
// =====================================================================
// Host
// =====================================================================
void cpu_ref(const e4m3* a, const e4m3* b, float* c) {
  for (int m = 0; m < M; m++)
    for (int n = 0; n < N; n++) {
      float s = 0;
      for (int k = 0; k < K; k++)
        s += float(a[m * K + k]) * float(b[n * K + k]);
      c[m * N + n] = s;
    }
}
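// With the all-ones inputs set up in run(), every reference output is exactly
// K = 128.0f and fp32 accumulation reproduces it bit-exactly, so the 1.0
// tolerance in check() is generous rather than tight.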
bool check(const float* gpu, const float* ref, const char* name) {
  int errs = 0; float mx = 0;
  for (int i = 0; i < M * N; i++) {
    float e = fabsf(gpu[i] - ref[i]);
    mx = fmaxf(mx, e);
    if (e > 1.0f && errs++ < 3)
      printf(" MISMATCH [%d]: gpu=%.2f ref=%.2f\n", i, gpu[i], ref[i]);
  }
  printf("[%s] max_err=%.4f errors=%d/%d %s\n", name, mx, errs, M * N, errs ? "FAIL" : "PASS");
  return errs == 0;
}
template <typename Kernel1, typename KernelN>
bool run(const char* name, Kernel1 k1, KernelN kN,
         e4m3* da, e4m3* db, float* dc, e4m3* ha, e4m3* hb, float* hg, float* hr) {
  constexpr size_t ss = sizeof(SharedMem);
  cudaFuncSetAttribute(k1, cudaFuncAttributeMaxDynamicSharedMemorySize, ss);
  cudaFuncSetAttribute(kN, cudaFuncAttributeMaxDynamicSharedMemorySize, ss);
  int64_t *dcyc; cudaMalloc(&dcyc, 8);
  for (int i = 0; i < M * K; i++) ha[i] = e4m3(1.0f);
  for (int i = 0; i < N * K; i++) hb[i] = e4m3(1.0f);
  cudaMemcpy(da, ha, M * K, cudaMemcpyHostToDevice);
  cudaMemcpy(db, hb, N * K, cudaMemcpyHostToDevice);
  cudaMemset(dc, 0, M * N * 4);
  // Correctness: single-iteration kernel vs CPU reference.
  k1<<<1, 128, ss>>>(dc, da, db, nullptr);
  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  if (err != cudaSuccess) { printf("[%s] CUDA error: %s\n", name, cudaGetErrorString(err)); cudaFree(dcyc); return false; }
  cudaMemcpy(hg, dc, M * N * 4, cudaMemcpyDeviceToHost);
  cpu_ref(ha, hb, hr);
  char buf[64]; snprintf(buf, 64, "%s_ones", name);
  if (!check(hg, hr, buf)) { cudaFree(dcyc); return false; }
  // Throughput: five warm-up launches, then one timed launch of BENCH_ITERS iterations.
  for (int w = 0; w < 5; w++) kN<<<1, 128, ss>>>(dc, da, db, nullptr);
  cudaDeviceSynchronize();
  kN<<<1, 128, ss>>>(dc, da, db, dcyc);
  cudaDeviceSynchronize();
  int64_t hcyc; cudaMemcpy(&hcyc, dcyc, 8, cudaMemcpyDeviceToHost);
  printf("[%s] MMA-only: %ld cycles / %d iters = %ld cycles/iter\n",
         name, hcyc, BENCH_ITERS, hcyc / BENCH_ITERS);
  cudaFree(dcyc);
  return true;
}
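// Timing note: clock64() brackets issue + tcgen05.commit + bar.wait on the
// elected thread, so the reported number is end-to-end pipeline time for
// BENCH_ITERS back-to-back iterations (after the warm-up launches above), not
// the latency of a single MMA instruction.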
int main() {
  e4m3 *ha = new e4m3[M * K], *hb = new e4m3[N * K];
  float *hg = new float[M * N], *hr = new float[M * N];
  e4m3 *da, *db; float *dc;
  cudaMalloc(&da, M * K); cudaMalloc(&db, N * K); cudaMalloc(&dc, M * N * 4);
  bool ok = true;
  printf("=== fp8 SS MMA: NoSwizzle vs SW32 vs SW64 vs SW128 ===\n");
  printf("=== M=%d N=%d K=%d, %d iters, device-side clock64() ===\n\n", M, N, K, BENCH_ITERS);
  ok &= run("NoSwizzle", mma_noswizzle<1>, mma_noswizzle<BENCH_ITERS>, da, db, dc, ha, hb, hg, hr);
  ok &= run("SW32", mma_swizzle<Layout_SW32, 1>, mma_swizzle<Layout_SW32, BENCH_ITERS>, da, db, dc, ha, hb, hg, hr);
  ok &= run("SW64", mma_swizzle<Layout_SW64, 1>, mma_swizzle<Layout_SW64, BENCH_ITERS>, da, db, dc, ha, hb, hg, hr);
  ok &= run("SW128", mma_swizzle<Layout_SW128, 1>, mma_swizzle<Layout_SW128, BENCH_ITERS>, da, db, dc, ha, hb, hg, hr);
  printf("\n%s\n", ok ? "ALL PASSED" : "SOME FAILED");
  delete[] ha; delete[] hb; delete[] hg; delete[] hr;
  cudaFree(da); cudaFree(db); cudaFree(dc);
  return ok ? 0 : 1;
}