Created
February 18, 2026 03:27
-
-
Save aidando73/ab3f91f618b29ff7edf5c0a0d6d5aedd to your computer and use it in GitHub Desktop.
Minimal repro: fp8 SS MMA with NoSwizzle vs SW32 vs SW64 vs SW128 on SM100 (B200). All four swizzle modes produce identical MMA throughput.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| // Minimal repro: fp8 SS MMA with NoSwizzle vs SW32 vs SW64 vs SW128. | |
| // All four produce identical MMA throughput (129 cycles/iter). | |
| // Only depends on CUTLASS (no FlashMLA, no mla_decode_v21). | |
| // | |
| // Build & run (requires sm_100a GPU + CUTLASS source): | |
| // nvcc -std=c++20 -O2 --generate-code=arch=compute_100a,code=[sm_100a] \ | |
| // -I third-party/cutlass/include \ | |
| // -o scripts/swizzle_mma_repro scripts/swizzle_mma_repro.cu \ | |
| // && ./scripts/swizzle_mma_repro | |
| #include <cstdio> | |
| #include <cstdlib> | |
| #include <cmath> | |
| #include <cstring> | |
| #include <cuda_fp8.h> | |
| #include <cuda_runtime.h> | |
| #include <cutlass/arch/barrier.h> | |
| #include <cute/tensor.hpp> | |
| #include <cute/arch/tmem_allocator_sm100.hpp> | |
| #include <cute/atom/mma_traits_sm100.hpp> | |
using namespace cute;
using e4m3 = cutlass::float_e4m3_t;
using transac_bar_t = cutlass::arch::ClusterTransactionBarrier;
// Problem shape: a single 64x64x128 fp8 GEMM tile per kernel launch.
constexpr int M = 64, N = 64, K = 128;
// Elements per SMEM operand buffer (fp8 is 1 byte each); B also uses N*K == M*K here.
constexpr int BUF = M * K;
// MMA loop iterations for the timed benchmark kernels.
constexpr int BENCH_ITERS = 100;
| // ===================================================================== | |
| // Inlined SS MMA atom (from fp8_mma_atom.h, no external dependency) | |
| // ===================================================================== | |
namespace cute {
// SS (SMEM x SMEM) MMA atom issuing tcgen05.mma.ws with kind::f8f6f4 for a
// single CTA (cta_group::1). Mirrors the shape of CUTLASS's SM100 atoms but is
// inlined here so this repro depends only on CUTLASS headers.
template <class a_type, class b_type, class c_type,
          int M_, int N_, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg = UMMA::ScaleIn::One,
          UMMA::ScaleIn b_neg = UMMA::ScaleIn::One>
struct SM100_MMA_F8F6F4_WS_SS_REPRO {
  using DRegisters = void;
  using ARegisters = uint64_t[1];  // SMEM matrix descriptor for A
  using BRegisters = uint64_t[1];  // SMEM matrix descriptor for B
  using CRegisters = uint32_t[1];  // TMEM address of the C accumulator
  // Issue one MMA: C(tmem_c) = A*B, accumulating into C when scaleC != 0 and
  // overwriting C when scaleC == 0. Only the high 32 bits of idescE (the
  // runtime instruction descriptor) are passed to the PTX instruction.
  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a, uint64_t const& desc_b,
      uint32_t const& tmem_c, uint32_t const& scaleC, uint64_t const& idescE) {
    asm volatile(
        "{\n\t"
        ".reg .pred p;\n\t"
        "setp.ne.b32 p, %4, 0;\n\t"  // p = (scaleC != 0): accumulate-into-C predicate
        "tcgen05.mma.ws.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, p, 0;\n\t"
        "}\n"
        :
        : "r"(tmem_c), "l"(desc_a), "l"(desc_b),
          "r"(uint32_t(idescE >> 32)), "r"(scaleC));
  }
};
// MMA_Traits specialization so the atom above composes with make_tiled_mma()
// and cute::gemm(). A/B are SMEM-descriptor fragments; C lives in TMEM.
template <class a_type, class b_type, class c_type,
          int M_, int N_, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
struct MMA_Traits<SM100_MMA_F8F6F4_WS_SS_REPRO<a_type, b_type, c_type,
                  M_, N_, a_major, b_major, a_neg, b_neg>> {
  using ValTypeD = c_type;
  using ValTypeA = a_type;
  using ValTypeB = b_type;
  using ValTypeC = c_type;
  using FrgTypeA = UMMA::smem_desc<a_major>;
  using FrgTypeB = UMMA::smem_desc<b_major>;
  using FrgTypeC = UMMA::tmem_frg_ws_1sm<c_type>;
  // One atom issue consumes 32 elements along K.
  static constexpr int K_ = 32;
  using Shape_MNK = Shape<Int<M_>, Int<N_>, Int<K_>>;
  // Single-thread issue: the whole atom belongs to one logical thread.
  using ThrID = Layout<_1>;
  using ALayout = Layout<Shape<_1, Shape<Int<M_>, Int<K_>>>,
                         Stride<_0, Stride<_1, Int<M_>>>>;
  using BLayout = Layout<Shape<_1, Shape<Int<N_>, Int<K_>>>,
                         Stride<_0, Stride<_1, Int<N_>>>>;
  using CLayout = Layout<Shape<_1, Shape<Int<M_>, Int<N_>>>,
                         Stride<_0, Stride<_1, Int<M_>>>>;
  // Mutable issue state: first issue of an accumulation chain uses Zero.
  UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;
  // Static instruction descriptor built from the template parameters.
  UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
      a_type, b_type, c_type, M_, N_, a_major, b_major, a_neg, b_neg>();
  // Unpack CuTe tensors into raw descriptors/addresses and issue the MMA.
  template <class TD, class DL, class TA, class AL, class TB, class BL, class TC, class CL>
  CUTE_HOST_DEVICE constexpr friend void
  mma_unpack(MMA_Traits const& traits,
             Tensor<TD, DL>& D, Tensor<TA, AL> const& A,
             Tensor<TB, BL> const& B, Tensor<TC, CL> const& C) {
    uint64_t desc_a = A[0], desc_b = B[0];
    uint32_t tmem_c = raw_pointer_cast(D.data());
    uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);
    SM100_MMA_F8F6F4_WS_SS_REPRO<a_type, b_type, c_type, M_, N_, a_major, b_major,
                                 a_neg, b_neg>::fma(desc_a, desc_b, tmem_c,
                                                    uint32_t(traits.accumulate_), idesc);
  }
};
} // namespace cute
| // ===================================================================== | |
| // Inlined helpers (from kerutils, pure PTX) | |
| // ===================================================================== | |
// tcgen05 ordering fence to place BEFORE a thread synchronization point
// (orders prior tcgen05 accesses with respect to the upcoming sync).
__device__ __forceinline__ void tcgen05_fence_before() {
  asm volatile("tcgen05.fence::before_thread_sync;");
}
// tcgen05 ordering fence to place AFTER a thread synchronization point
// (orders subsequent tcgen05 accesses with respect to the preceding sync).
__device__ __forceinline__ void tcgen05_fence_after() {
  asm volatile("tcgen05.fence::after_thread_sync;");
}
// Commit the outstanding tcgen05 MMA group: `bar` receives one mbarrier
// arrival once all previously issued MMAs have completed.
__device__ __forceinline__ void umma_commit(transac_bar_t& bar) {
  uint32_t p = cute::cast_smem_ptr_to_uint(&bar);  // barrier's SMEM address as u32
  asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];" :: "r"(p));
}
// Issue one full K-loop of SS MMAs over CuTe-partitioned SMEM operands.
// `clear` selects ScaleOut::Zero for the FIRST k-step only (overwrites the
// accumulator); every later k-step accumulates. Mutates mma.accumulate_ as a
// side effect, so callers must not rely on its value afterwards.
template <class TiledMMA, class TA, class TB, class TC>
__device__ void do_ss_mma(TiledMMA& mma, TA sA, TB sB, TC tC, bool clear) {
  mma.accumulate_ = clear ? UMMA::ScaleOut::Zero : UMMA::ScaleOut::One;
  auto thr = mma.get_slice(_0{});
  auto fA = thr.partition_fragment_A(sA);
  auto fB = thr.partition_fragment_B(sB);
  CUTE_UNROLL
  for (int k = 0; k < size<2>(fA); ++k) {
    cute::gemm(mma, fA(_, _, k), fB(_, _, k), tC);
    mma.accumulate_ = UMMA::ScaleOut::One;  // after the first k-step, always accumulate
  }
}
| // ===================================================================== | |
| // Types | |
| // ===================================================================== | |
// Single-atom tiled MMA: K-major A and B, fp32 accumulation, one CTA.
using TiledMMA_SS = decltype(make_tiled_mma(
    SM100_MMA_F8F6F4_WS_SS_REPRO<e4m3, e4m3, float, M, N, UMMA::Major::K, UMMA::Major::K>{}));
// K-major SMEM layouts for the (M, K) operand tile, one per UMMA swizzle atom.
using Layout_SW32 = decltype(tile_to_shape(UMMA::Layout_K_SW32_Atom<e4m3>{},
    Shape<Int<M>, Int<K>>{}, Step<_1, _2>{}));
using Layout_SW64 = decltype(tile_to_shape(UMMA::Layout_K_SW64_Atom<e4m3>{},
    Shape<Int<M>, Int<K>>{}, Step<_1, _2>{}));
using Layout_SW128 = decltype(tile_to_shape(UMMA::Layout_K_SW128_Atom<e4m3>{},
    Shape<Int<M>, Int<K>>{}, Step<_1, _2>{}));
// Dynamic shared memory image used by both kernels.
struct SharedMem {
  array_aligned<e4m3, BUF> a;            // A operand tile (M x K fp8)
  array_aligned<e4m3, BUF> b;            // B operand tile (N x K fp8; N == M here)
  array_aligned<uint32_t, 1> tmem_addr;  // written by the TMEM allocator
  transac_bar_t bar;                     // signaled by umma_commit()
};
| // ===================================================================== | |
| // Swizzled MMA kernel (CuTE descriptors) | |
| // ===================================================================== | |
// Benchmark kernel for the CuTe-descriptor SS MMA path with SMEM layout
// `Layout` (SW32/SW64/SW128; the no-swizzle path is the kernel below).
// Launch config: 1 block x 128 threads, dynamic SMEM = sizeof(SharedMem).
//   out    : M x N fp32 result, row-major, read back from TMEM
//   a_vals : M x K fp8 A (K-major);  b_vals : N x K fp8 B (K-major)
//   cycles : when non-null, receives the clock64() delta spanning all ITERS loops
template <typename Layout, int ITERS>
__global__ void __launch_bounds__(128, 1)
mma_swizzle(float* out, const e4m3* a_vals, const e4m3* b_vals, int64_t* cycles) {
  extern __shared__ char sbuf[];
  auto& sm = *reinterpret_cast<SharedMem*>(sbuf);
  int warp = cutlass::canonical_warp_idx_sync();
  int lane = threadIdx.x % 32;
  int tid = threadIdx.x % 128;
  // Warp 0 initializes the transaction barrier and allocates 128 TMEM columns.
  if (warp == 0) {
    if (elect_one_sync()) { sm.bar.init(1); cutlass::arch::fence_barrier_init(); }
    cute::TMEM::Allocator1Sm().allocate(128, sm.tmem_addr.data());
    cute::TMEM::Allocator1Sm().release_allocation_lock();
  }
  __syncthreads();
  // Stage A/B from global into SMEM through the swizzled layout; the layout
  // object performs the address transformation, so indexing is logical (r, c).
  {
    Tensor sA = make_tensor(make_smem_ptr(sm.a.data()), Layout{});
    Tensor sB = make_tensor(make_smem_ptr(sm.b.data()), Layout{});
    for (int i = tid; i < BUF; i += 128) {
      int r = i / K, c = i % K;
      sA(r, c) = a_vals[r * K + c];
    }
    for (int i = tid; i < BUF; i += 128) {
      int r = i / K, c = i % K;
      sB(r, c) = b_vals[r * K + c];
    }
    cutlass::arch::fence_view_async_shared();  // make SMEM writes visible to the MMA
  }
  __syncthreads();
  int64_t t0 = 0, t1 = 0;
  // One elected thread issues every MMA.
  if (warp == 0 && elect_one_sync()) {
    TiledMMA_SS mma{};
    Tensor sA = make_tensor(make_smem_ptr(sm.a.data()), Layout{});
    Tensor sB = make_tensor(make_smem_ptr(sm.b.data()), Layout{});
    Tensor tC = partition_fragment_C(mma, Shape<Int<M>, Int<N>>{});
    // NOTE(review): TMEM address is forced to 0 — assumes the allocation above
    // starts at column 0; confirm against sm.tmem_addr if this ever changes.
    tC.data().get() = 0;
    t0 = clock64();
    do_ss_mma(mma, sA, sB, tC, true);   // first iteration clears the accumulator
    CUTE_UNROLL
    for (int i = 1; i < ITERS; i++)
      do_ss_mma(mma, sA, sB, tC, false);
    umma_commit(sm.bar);                // barrier arrives when all MMAs retire
  }
  if (warp == 0) sm.bar.wait(0);        // phase 0: first (and only) use of the barrier
  // t0/t1 are thread-local; assumes elect_one_sync() picks the same lane here
  // as in the issue block above — TODO confirm against the elect.sync spec.
  if (warp == 0 && elect_one_sync()) { t1 = clock64(); if (cycles) *cycles = t1 - t0; }
  __syncthreads();
  // All 128 threads read the 64x64 fp32 accumulator back from TMEM: each
  // datapath lane pulls N/2 = 32 columns, split across two 64-row halves.
  {
    tcgen05_fence_after();
    uint32_t cv[N / 2];
    cute::SM100_TMEM_LOAD_32dp32b32x::copy(0, cv[0], cv[1], cv[2], cv[3], cv[4], cv[5],
        cv[6], cv[7], cv[8], cv[9], cv[10], cv[11], cv[12], cv[13], cv[14], cv[15],
        cv[16], cv[17], cv[18], cv[19], cv[20], cv[21], cv[22], cv[23], cv[24], cv[25],
        cv[26], cv[27], cv[28], cv[29], cv[30], cv[31]);
    cutlass::arch::fence_view_async_tmem_load();
    tcgen05_fence_before();
    int dp = warp * 32 + lane;
    int noff = (dp < 64) ? 0 : (N / 2);
    int m = (dp < 64) ? dp : dp - 64;
    if (m < M)
      for (int c = 0; c < N / 2; c++)
        out[m * N + c + noff] = __uint_as_float(cv[c]);
  }
  if (warp == 0) cute::TMEM::Allocator1Sm().free(0, 128);
}
| // ===================================================================== | |
| // No-swizzle MMA kernel (manual descriptors, interleaved SMEM format) | |
| // ===================================================================== | |
// Hand-build a SWIZZLE_NONE UMMA shared-memory descriptor.
// The start address is encoded at 16-byte granularity (>> 4); lbo/sbo are the
// leading/stride byte-offset fields of the descriptor.
__device__ uint64_t make_desc_noswz(uint32_t smem_addr, int leading_off, int stride_off) {
  UMMA::SmemDescriptor desc;
  desc.desc_ = 0;  // start from all-zero bits
  desc.layout_type_ = uint8_t(UMMA::LayoutType::SWIZZLE_NONE);
  desc.version_ = 1;
  desc.lbo_mode_ = 0;
  desc.start_address_ = smem_addr >> 4;
  desc.stride_byte_offset_ = stride_off;
  desc.leading_byte_offset_ = leading_off;
  return desc.desc_;
}
// Byte offset of (row, col) in the interleaved SMEM format: 8-row x 16-byte
// tiles of 128 bytes, with 8-row panels laid out contiguously along K.
__device__ int interleaved_off(int row, int col) {
  int panel = row / 8, row_in_panel = row % 8;
  int group = col / 16, col_in_group = col % 16;
  return panel * (8 * K) + group * 128 + row_in_panel * 16 + col_in_group;
}
// Benchmark kernel for the no-swizzle path: A/B staged in the interleaved SMEM
// format (see interleaved_off) and consumed via hand-built SWIZZLE_NONE UMMA
// descriptors, K/32 descriptor steps per MMA iteration.
// Same launch config, outputs, and epilogue as mma_swizzle.
template <int ITERS>
__global__ void __launch_bounds__(128, 1)
mma_noswizzle(float* out, const e4m3* a_vals, const e4m3* b_vals, int64_t* cycles) {
  extern __shared__ char sbuf[];
  auto& sm = *reinterpret_cast<SharedMem*>(sbuf);
  int warp = cutlass::canonical_warp_idx_sync();
  int lane = threadIdx.x % 32;
  int tid = threadIdx.x % 128;
  // Warp 0 initializes the transaction barrier and allocates 128 TMEM columns.
  if (warp == 0) {
    if (elect_one_sync()) { sm.bar.init(1); cutlass::arch::fence_barrier_init(); }
    cute::TMEM::Allocator1Sm().allocate(128, sm.tmem_addr.data());
    cute::TMEM::Allocator1Sm().release_allocation_lock();
  }
  __syncthreads();
  // Stage operands byte-by-byte into the interleaved layout.
  {
    auto* ar = reinterpret_cast<uint8_t*>(sm.a.data());
    auto* br = reinterpret_cast<uint8_t*>(sm.b.data());
    for (int i = tid; i < BUF; i += 128) {
      int r = i / K, c = i % K;
      ar[interleaved_off(r, c)] = reinterpret_cast<const uint8_t*>(a_vals)[r * K + c];
    }
    for (int i = tid; i < BUF; i += 128) {
      int r = i / K, c = i % K;
      br[interleaved_off(r, c)] = reinterpret_cast<const uint8_t*>(b_vals)[r * K + c];
    }
    cutlass::arch::fence_view_async_shared();  // make SMEM writes visible to the MMA
  }
  __syncthreads();
  int64_t t0 = 0, t1 = 0;
  // One elected thread builds descriptors and issues every MMA.
  if (warp == 0 && elect_one_sync()) {
    uint32_t sa = static_cast<uint32_t>(__cvta_generic_to_shared(sm.a.data()));
    uint32_t sb = static_cast<uint32_t>(__cvta_generic_to_shared(sm.b.data()));
    // NOTE(review): descriptor offsets are in 16-byte units; SBO spans one
    // 8-row panel (8*K bytes) — confirm against the PTX descriptor spec.
    constexpr int LBO = 8, SBO = (8 * K) / 16;
    UMMA::InstrDescriptor id = UMMA::make_instr_desc<
        e4m3, e4m3, float, M, N, UMMA::Major::K, UMMA::Major::K>();
    // Only the high 32 bits of the runtime descriptor feed the PTX instruction.
    uint32_t idhi = uint32_t(UMMA::make_runtime_instr_desc<>(id) >> 32);
    t0 = clock64();
    CUTE_UNROLL
    for (int it = 0; it < ITERS; it++) {
      uint32_t acc = (it == 0) ? 0 : 1;  // clear C on the very first MMA only
      CUTE_UNROLL
      for (int k = 0; k < K / 32; k++) {
        // Each k-step advances two 16-column groups (32 fp8 columns) = 256 bytes.
        uint64_t da = make_desc_noswz(sa + k * 256, LBO, SBO);
        uint64_t db = make_desc_noswz(sb + k * 256, LBO, SBO);
        // Same tcgen05 MMA as the atom; TMEM address hard-coded to 0
        // (assumes the allocation above starts at column 0 — TODO confirm).
        asm volatile(
            "{\n\t"
            ".reg .pred p;\n\t"
            "setp.ne.b32 p, %4, 0;\n\t"
            "tcgen05.mma.ws.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, p, 0;\n\t"
            "}\n"
            :: "r"(0u), "l"(da), "l"(db), "r"(idhi), "r"(acc));
        acc = 1;  // accumulate from the second k-step onward
      }
    }
    umma_commit(sm.bar);                // barrier arrives when all MMAs retire
  }
  if (warp == 0) sm.bar.wait(0);        // phase 0: first (and only) use of the barrier
  // t0/t1 are thread-local; assumes elect_one_sync() picks the same lane here
  // as in the issue block above — TODO confirm against the elect.sync spec.
  if (warp == 0 && elect_one_sync()) { t1 = clock64(); if (cycles) *cycles = t1 - t0; }
  __syncthreads();
  // Read back the 64x64 fp32 accumulator from TMEM (identical to mma_swizzle).
  {
    tcgen05_fence_after();
    uint32_t cv[N / 2];
    cute::SM100_TMEM_LOAD_32dp32b32x::copy(0, cv[0], cv[1], cv[2], cv[3], cv[4], cv[5],
        cv[6], cv[7], cv[8], cv[9], cv[10], cv[11], cv[12], cv[13], cv[14], cv[15],
        cv[16], cv[17], cv[18], cv[19], cv[20], cv[21], cv[22], cv[23], cv[24], cv[25],
        cv[26], cv[27], cv[28], cv[29], cv[30], cv[31]);
    cutlass::arch::fence_view_async_tmem_load();
    tcgen05_fence_before();
    int dp = warp * 32 + lane;
    int noff = (dp < 64) ? 0 : (N / 2);
    int m_ = (dp < 64) ? dp : dp - 64;
    if (m_ < M)
      for (int c = 0; c < N / 2; c++)
        out[m_ * N + c + noff] = __uint_as_float(cv[c]);
  }
  if (warp == 0) cute::TMEM::Allocator1Sm().free(0, 128);
}
| // ===================================================================== | |
| // Host | |
| // ===================================================================== | |
// Host reference GEMM for the correctness check:
// c[m, n] = sum_k a[m, k] * b[n, k] — both operands are K-major.
void cpu_ref(const e4m3* a, const e4m3* b, float* c) {
  for (int m = 0; m < M; ++m) {
    const e4m3* a_row = a + m * K;
    for (int n = 0; n < N; ++n) {
      const e4m3* b_row = b + n * K;
      float acc = 0;
      for (int k = 0; k < K; ++k)
        acc += float(a_row[k]) * float(b_row[k]);
      c[m * N + n] = acc;
    }
  }
}
// Compare a GPU result tile against the CPU reference with tolerance 1.0.
// Prints at most three mismatching elements plus a one-line summary; returns
// true iff no element differs by more than 1.0.
bool check(const float* gpu, const float* ref, const char* name) {
  int errs = 0;
  float mx = 0;
  for (int i = 0; i < M * N; i++) {
    float diff = fabsf(gpu[i] - ref[i]);
    if (diff > mx) mx = diff;
    if (diff > 1.0f) {
      if (errs < 3)
        printf(" MISMATCH [%d]: gpu=%.2f ref=%.2f\n", i, gpu[i], ref[i]);
      errs++;
    }
  }
  printf("[%s] max_err=%.4f errors=%d/%d %s\n", name, mx, errs, M * N, errs ? "FAIL" : "PASS");
  return errs == 0;
}
// Run one kernel variant end-to-end:
//   1. correctness — one-iteration kernel (k1) vs CPU reference, all-ones inputs;
//   2. timing      — 5 warmup launches of kN, then one timed launch whose
//                    device-side clock64() delta is read back and reported.
// Returns false on any CUDA error or correctness mismatch.
// Fixes vs original: benchmark launches are now error-checked (a failed bench
// run previously printed garbage cycle counts), cudaMalloc is checked, byte
// counts use sizeof(), and int64_t is printed portably (%lld, not %ld).
template <typename Kernel1, typename KernelN>
bool run(const char* name, Kernel1 k1, KernelN kN,
         e4m3* da, e4m3* db, float* dc, e4m3* ha, e4m3* hb, float* hg, float* hr) {
  constexpr size_t ss = sizeof(SharedMem);
  // Both kernels need more dynamic SMEM than the default per-block limit.
  cudaFuncSetAttribute(k1, cudaFuncAttributeMaxDynamicSharedMemorySize, ss);
  cudaFuncSetAttribute(kN, cudaFuncAttributeMaxDynamicSharedMemorySize, ss);
  int64_t* dcyc = nullptr;
  if (cudaMalloc(&dcyc, sizeof(int64_t)) != cudaSuccess) {
    printf("[%s] cudaMalloc failed\n", name);
    return false;
  }
  // All-ones inputs: every output element should be exactly K (= 128.0f).
  for (int i = 0; i < M * K; i++) ha[i] = e4m3(1.0f);
  for (int i = 0; i < N * K; i++) hb[i] = e4m3(1.0f);
  cudaMemcpy(da, ha, M * K * sizeof(e4m3), cudaMemcpyHostToDevice);
  cudaMemcpy(db, hb, N * K * sizeof(e4m3), cudaMemcpyHostToDevice);
  cudaMemset(dc, 0, M * N * sizeof(float));
  // Correctness pass (single MMA iteration).
  k1<<<1, 128, ss>>>(dc, da, db, nullptr);
  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  if (err != cudaSuccess) { printf("[%s] CUDA error: %s\n", name, cudaGetErrorString(err)); cudaFree(dcyc); return false; }
  cudaMemcpy(hg, dc, M * N * sizeof(float), cudaMemcpyDeviceToHost);
  cpu_ref(ha, hb, hr);
  char buf[64]; snprintf(buf, 64, "%s_ones", name);
  if (!check(hg, hr, buf)) { cudaFree(dcyc); return false; }
  // Timing pass: warm up, then one measured launch writing its cycle delta.
  for (int w = 0; w < 5; w++) kN<<<1, 128, ss>>>(dc, da, db, nullptr);
  cudaDeviceSynchronize();
  kN<<<1, 128, ss>>>(dc, da, db, dcyc);
  cudaDeviceSynchronize();
  err = cudaGetLastError();  // previously unchecked: a failed bench launch left cycles garbage
  if (err != cudaSuccess) { printf("[%s] CUDA error: %s\n", name, cudaGetErrorString(err)); cudaFree(dcyc); return false; }
  int64_t hcyc; cudaMemcpy(&hcyc, dcyc, sizeof(int64_t), cudaMemcpyDeviceToHost);
  // Cast to long long: int64_t is not guaranteed to be `long` on all hosts.
  printf("[%s] MMA-only: %lld cycles / %d iters = %lld cycles/iter\n",
         name, (long long)hcyc, BENCH_ITERS, (long long)(hcyc / BENCH_ITERS));
  cudaFree(dcyc);
  return true;
}
// Host driver: allocate host/device buffers once, then exercise all four
// swizzle variants (each with a 1-iteration correctness kernel and a
// BENCH_ITERS-iteration timing kernel). Exit code 0 iff every variant passed.
int main() {
  // Host-side staging and reference buffers.
  e4m3* ha = new e4m3[M * K];
  e4m3* hb = new e4m3[N * K];
  float* hg = new float[M * N];
  float* hr = new float[M * N];
  // Device buffers shared by every variant.
  e4m3 *da, *db;
  float* dc;
  cudaMalloc(&da, M * K);
  cudaMalloc(&db, N * K);
  cudaMalloc(&dc, M * N * 4);
  printf("=== fp8 SS MMA: NoSwizzle vs SW32 vs SW64 vs SW128 ===\n");
  printf("=== M=%d N=%d K=%d, %d iters, device-side clock64() ===\n\n", M, N, K, BENCH_ITERS);
  // Note: `&=` is non-short-circuiting, so every variant runs even after a failure.
  bool ok = run("NoSwizzle", mma_noswizzle<1>, mma_noswizzle<BENCH_ITERS>, da, db, dc, ha, hb, hg, hr);
  ok &= run("SW32", mma_swizzle<Layout_SW32, 1>, mma_swizzle<Layout_SW32, BENCH_ITERS>, da, db, dc, ha, hb, hg, hr);
  ok &= run("SW64", mma_swizzle<Layout_SW64, 1>, mma_swizzle<Layout_SW64, BENCH_ITERS>, da, db, dc, ha, hb, hg, hr);
  ok &= run("SW128", mma_swizzle<Layout_SW128, 1>, mma_swizzle<Layout_SW128, BENCH_ITERS>, da, db, dc, ha, hb, hg, hr);
  printf("\n%s\n", ok ? "ALL PASSED" : "SOME FAILED");
  delete[] ha; delete[] hb; delete[] hg; delete[] hr;
  cudaFree(da); cudaFree(db); cudaFree(dc);
  return ok ? 0 : 1;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment