Last active
March 1, 2026 18:59
-
-
Save LunNova/f968233ae3a73916b488304a74c20722 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu | |
| index a6cf63f..3ae2583 100644 | |
| --- a/csrc/rocm/skinny_gemms.cu | |
| +++ b/csrc/rocm/skinny_gemms.cu | |
| @@ -21,7 +21,8 @@ | |
| // However, it may be possible to fix these kernels to handle both issues. | |
| #if defined(__HIPCC__) && \ | |
| - (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) | |
| + (defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || \ | |
| + defined(__gfx942__) || defined(__gfx950__)) | |
| #define __HIP__GFX9__ | |
| #endif | |
| @@ -285,9 +286,19 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b, | |
| return out_c; | |
| } | |
| +// gfx906 has v_dot2_f32_f16 (VOP3P, 3-src: dst = dot2(s0,s1) + s2) | |
| +// gfx908+ has v_dot2c_f32_f16 (VOP2, compact: dst += dot2(s0,s1)) | |
| +#if defined(__gfx906__) | |
| +#define DOT2C_F16_ASM(V0, V2, V3) \ | |
| + asm("v_dot2_f32_f16 %0, %2, %3, %1" : "=v"(V0) : "0"(V0), "v"(V2), "v"(V3)); | |
| +#else | |
| +#define DOT2C_F16_ASM(V0, V2, V3) \ | |
| + asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(V0) : "0"(V0), "v"(V2), "v"(V3)); | |
| +#endif | |
| + | |
| #define DOT2C(V0, V2, V3) \ | |
| if constexpr (std::is_same_v<scalar_t, half>) { \ | |
| - asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(V0) : "0"(V0), "v"(V2), "v"(V3)); \ | |
| + DOT2C_F16_ASM(V0, V2, V3) \ | |
| } else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) { \ | |
| float2 s = __bfloat1622float2(*((__hip_bfloat162*)(&(V2)))) * \ | |
| __bfloat1622float2(*((__hip_bfloat162*)(&(V3)))); \ | |
| diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py | |
| index 6e67456..321386b 100644 | |
| --- a/vllm/platforms/rocm.py | |
| +++ b/vllm/platforms/rocm.py | |
| @@ -102,7 +102,8 @@ def on_mi3xx() -> bool: | |
| @cache | |
| def on_gfx9() -> bool: | |
| GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName | |
| - return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"]) | |
| + return any(arch in GPU_ARCH for arch in ["gfx906", "gfx908", "gfx90a", | |
| + "gfx942", "gfx950"]) | |
| @cache | |
| @@ -165,7 +166,7 @@ def use_rocm_custom_paged_attention( | |
| @cache | |
| def flash_attn_triton_available() -> bool: | |
| - if not on_gfx1x(): | |
| + if not (on_gfx1x() or on_gfx9()): | |
| return False | |
| try: | |
| from importlib.util import find_spec | |
| @@ -274,6 +275,10 @@ class RocmPlatform(Platform): | |
| f"is not MLA type while requested for MLA backend." | |
| ) | |
| + if selected_backend == AttentionBackendEnum.FLASH_ATTN: | |
| + logger.info("Using Flash Attention backend.") | |
| + return AttentionBackendEnum.FLASH_ATTN.get_path() | |
| + | |
| if selected_backend == AttentionBackendEnum.FLEX_ATTENTION: | |
| logger.info("Using FlexAttention backend.") | |
| return AttentionBackendEnum.FLEX_ATTENTION.get_path() | |
| diff --git a/vllm/v1/attention/backends/fa_utils.py b/vllm/v1/attention/backends/fa_utils.py | |
| index 988cf7c..6ee7c14 100644 | |
| --- a/vllm/v1/attention/backends/fa_utils.py | |
| +++ b/vllm/v1/attention/backends/fa_utils.py | |
| @@ -23,6 +23,8 @@ elif current_platform.is_xpu(): | |
| get_scheduler_metadata = ipex_ops.get_scheduler_metadata # type: ignore[assignment] | |
| elif current_platform.is_rocm(): | |
| + from vllm._custom_ops import reshape_and_cache_flash # type: ignore[no-redef] | |
| + | |
| try: | |
| from flash_attn import flash_attn_varlen_func # type: ignore[no-redef] | |
| except ImportError: | |
| @@ -33,6 +35,9 @@ elif current_platform.is_rocm(): | |
| "to be installed. Please install flash-attn first." | |
| ) | |
| + def get_scheduler_metadata(*args: Any, **kwargs: Any) -> Any: # type: ignore[no-redef,misc] | |
| + raise NotImplementedError("get_scheduler_metadata is not supported on ROCm") | |
| + | |
| def get_flash_attn_version(requires_alibi: bool = False) -> int | None: | |
| # import here to avoid circular dependencies | |
| @@ -127,4 +132,5 @@ def flash_attn_supports_mla(): | |
| def is_flash_attn_varlen_func_available() -> bool: | |
| - return current_platform.is_cuda() or current_platform.is_xpu() | |
| + return (current_platform.is_cuda() or current_platform.is_xpu() | |
| + or current_platform.is_rocm()) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # TRITON_ATTN is the only attention backend I have working on gfx906 so far. | |
| nix develop -c env HSA_OVERRIDE_GFX_VERSION=9.0.6 vllm serve Qwen/Qwen3-VL-4B-Instruct --max-model-len 4096 --gpu-memory-utilization 0.95 --limit-mm-per-prompt '{"image": 0, "video": 0}' --attention-backend TRITON_ATTN |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Odd, I still see a failure after applying this patch, though I have not looked at this code path or the trace myself. Here is the log:
Error
Edit: Looks like it also needs to be added here: https://github.com/vllm-project/vllm/blob/1892993bc18e243e2c05841314c5e9c06a80c70d/vllm/platforms/rocm.py#L229