Last active: March 8, 2026 21:17
-
-
Save LunNova/0fb4e233edef3d2906755ab98db9c1f8 to your computer and use it in GitHub Desktop.
FYI: if you're reading this gist, gfx906 support has been merged into upstream Triton as of March 2026.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| diff --git a/test/Conversion/amd/tritongpu_to_llvm.mlir b/test/Conversion/amd/tritongpu_to_llvm.mlir | |
| --- a/test/Conversion/amd/tritongpu_to_llvm.mlir | |
| +++ b/test/Conversion/amd/tritongpu_to_llvm.mlir | |
| @@ -1,5 +1,6 @@ | |
| // RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx942 --convert-builtin-func-to-llvm | FileCheck %s | |
| // RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx950 | FileCheck %s --check-prefix=GFX950 | |
| +// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx906 | FileCheck %s --check-prefix=GFX906 | |
| module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { | |
| // CHECK-LABEL: atomic_add_f32_scalar | |
| @@ -633,3 +634,27 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr | |
| tt.return | |
| } | |
| } | |
| + | |
| +// ----- | |
| + | |
| +// GFX906-LABEL: v_dot_fp16_gfx906 | |
| +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0]}> | |
| +module attributes {"ttg.target" = "hip:gfx906", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { | |
| + tt.func @v_dot_fp16_gfx906(%arg0: tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, %arg1: tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, %arg2: tensor<16x16xf32, #blocked>) { | |
| + // GFX906-COUNT-8: llvm.call_intrinsic "llvm.amdgcn.fdot2" | |
| + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<16x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x16xf32, #blocked> | |
| + tt.return | |
| + } | |
| +} | |
| + | |
| +// ----- | |
| + | |
| +// GFX906-LABEL: v_dot_i8_gfx906 | |
| +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0]}> | |
| +module attributes {"ttg.target" = "hip:gfx906", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { | |
| + tt.func @v_dot_i8_gfx906(%arg0: tensor<16x16xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, %arg1: tensor<16x16xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, %arg2: tensor<16x16xi32, #blocked>) { | |
| + // GFX906-COUNT-4: llvm.call_intrinsic "llvm.amdgcn.sdot4" | |
| + %0 = tt.dot %arg0, %arg1, %arg2, inputPrecision = ieee : tensor<16x16xi8, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<16x16xi8, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x16xi32, #blocked> | |
| + tt.return | |
| + } | |
| +} | |
| diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h b/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h | |
| --- a/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h | |
| +++ b/third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h | |
| @@ -8,6 +8,7 @@ namespace mlir::triton::AMD { | |
| // A list of ISA families we care about. | |
| enum class ISAFamily { | |
| Unknown, | |
| + GFX906, | |
| CDNA1, | |
| CDNA2, | |
| CDNA3, | |
| diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp | |
| --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp | |
| +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp | |
| @@ -72,6 +72,7 @@ llvm::AMDGPU::GPUKind TargetInfo::getGPUKind() const { | |
| int TargetInfo::getWarpSize() const { | |
| switch (getISAFamily()) { | |
| + case ISAFamily::GFX906: | |
| case ISAFamily::CDNA1: | |
| case ISAFamily::CDNA2: | |
| case ISAFamily::CDNA3: | |
| @@ -395,10 +396,10 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc, | |
| return true; | |
| if (reduceLaneIdMask != (getWarpSize() - 1)) | |
| return false; | |
| - if (isCDNA(getISAFamily()) && getISAFamily() == ISAFamily::CDNA1) | |
| - return false; | |
| - if (isRDNA(getISAFamily()) && | |
| - llvm::is_contained({ISAFamily::RDNA1, ISAFamily::RDNA2}, getISAFamily())) | |
| + // DPP warp reduce requires gfx90a+ (CDNA2+) or gfx11+ (RDNA3+). | |
| + // Pre-CDNA2 GFX9 (gfx906/gfx908) and GFX10 (RDNA1/2) are excluded. | |
| + auto v = getIsaVersion(); | |
| + if (!((v.Major == 9 && (v.Minor > 0 || v.Stepping >= 0xa)) || v.Major >= 11)) | |
| return false; | |
| Operation *reduxOp = op.getSingleCombiner(); | |
| @@ -464,7 +465,7 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc, | |
| buf = createDppReduxOpWithBoundCtrl(valType, buf, 1 + dppCtrlRowShr, | |
| allRows, allBanks); | |
| - if (isCDNA(getISAFamily())) { | |
| + if (isCDNA(getISAFamily()) || getISAFamily() == ISAFamily::GFX906) { | |
| // row_bcast:15 row_mask:0xa | |
| buf = createDppReduxOpWithBoundCtrl( | |
| valType, buf, static_cast<uint32_t>(DppCtrl::BCAST15), 0xa, allBanks); | |
| diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp | |
| --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp | |
| +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp | |
| @@ -12,6 +12,9 @@ ISAFamily deduceISAFamily(llvm::StringRef arch) { | |
| if (kind == llvm::AMDGPU::GK_GFX1250) | |
| return ISAFamily::GFX1250; | |
| + if (kind == llvm::AMDGPU::GK_GFX906) | |
| + return ISAFamily::GFX906; | |
| + | |
| // CDNA ISA cases | |
| switch (kind) { | |
| case llvm::AMDGPU::GK_GFX950: | |
| @@ -41,6 +44,7 @@ ISAFamily deduceISAFamily(llvm::StringRef arch) { | |
| bool supportsVDot(llvm::StringRef arch) { | |
| switch (deduceISAFamily(arch)) { | |
| + case AMD::ISAFamily::GFX906: | |
| case AMD::ISAFamily::CDNA1: | |
| case AMD::ISAFamily::CDNA2: | |
| case AMD::ISAFamily::CDNA3: |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.