Created
March 5, 2026 19:12
-
-
Save shunting314/0e804bc3bcf71c085351d606f0665876 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py | |
| index 0748643a5..bbab180ae 100644 | |
| --- a/vllm/compilation/cuda_graph.py | |
| +++ b/vllm/compilation/cuda_graph.py | |
| @@ -7,6 +7,7 @@ from collections.abc import Callable | |
| from contextlib import ExitStack | |
| from typing import Any | |
| from unittest.mock import patch | |
| +from vllm.forward_context import ForwardContext, get_forward_context | |
| import torch | |
| @@ -298,5 +299,12 @@ class CUDAGraphWrapper: | |
| f"got {new_input_addresses}" | |
| ) | |
| + if cudagraph_runtime_mode == CUDAGraphMode.FULL: | |
| + from vllm.v1.attention.backends.flex_attention import g_kv_indices, g_kv_num_blocks | |
| + forward_context = get_forward_context() | |
| + attn_metadata = forward_context.attn_metadata | |
| + m0 = attn_metadata["model.layers.0.self_attn.attn"] | |
| + m1 = attn_metadata["model.layers.1.self_attn.attn"] | |
| + # breakpoint() | |
| entry.cudagraph.replay() | |
| return entry.output | |
| diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py | |
| index 8193c05c2..b29364ffc 100644 | |
| --- a/vllm/v1/attention/backends/flex_attention.py | |
| +++ b/vllm/v1/attention/backends/flex_attention.py | |
| @@ -2,6 +2,7 @@ | |
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | |
| """Attention layer with FlexAttention.""" | |
| +import os | |
| import math | |
| from dataclasses import dataclass | |
| from functools import cached_property | |
| @@ -40,6 +41,17 @@ from vllm.v1.attention.backends.utils import ( | |
| CommonAttentionMetadata, | |
| ) | |
| from vllm.v1.kv_cache_interface import AttentionSpec | |
| +from vllm.v1.attention.backends.utils import ( | |
| + AttentionCGSupport, | |
| +) | |
| + | |
| +g_kv_indices = None | |
| +g_kv_num_blocks = None | |
| +g_decode_offset = None | |
| +g_doc_ids = None | |
| +# self.seq_lens is fine | |
| +# self.query_start_loc.shape is fine | |
| +g_physical_to_logical = None | |
| logger = init_logger(__name__) | |
| @@ -617,6 +629,30 @@ class FlexAttentionMetadata: | |
| ).to(torch.int32) | |
| kv_num_blocks = (kv_indices >= 0).sum(dim=-1).to(torch.int32) | |
| + | |
| + global g_kv_indices, g_kv_num_blocks | |
| + | |
| + if g_kv_indices is None: | |
| + M = 1024 | |
| + N = 1024 | |
| + g_kv_indices = torch.empty(M, N, dtype=torch.int32, device="cuda") | |
| + g_kv_num_blocks = torch.empty(M, dtype=torch.int32, device="cuda") | |
| + | |
| + if True: | |
| + assert kv_indices.size(0) <= g_kv_indices.size(0) | |
| + # if g_kv_indices.shape != kv_indices.shape or g_kv_num_blocks.shape != kv_num_blocks.shape: breakpoint() | |
| + if kv_indices.size(1) > g_kv_indices.size(1): | |
| + breakpoint() | |
| + new_siz0, new_siz1 = kv_indices.size() | |
| + g_kv_indices.fill_(-1) | |
| + g_kv_num_blocks.zero_() | |
| + g_kv_indices[:new_siz0, :new_siz1].copy_(kv_indices) | |
| + g_kv_num_blocks[:new_siz0].copy_(kv_num_blocks) | |
| + kv_indices = g_kv_indices[:new_siz0, :new_siz1] | |
| + kv_num_blocks = g_kv_num_blocks[:new_siz0] | |
| + | |
| + # print(f"{kv_indices.shape=} {kv_indices.data_ptr()} {kv_num_blocks.shape=}, {kv_num_blocks.data_ptr()}") | |
| + | |
| block_mask_kwargs = { | |
| "seq_lengths": (self.num_actual_tokens, self.total_cache_tokens), | |
| "kv_num_blocks": kv_num_blocks[None, None], | |
| @@ -662,9 +698,35 @@ class FlexAttentionMetadata: | |
| self.block_mask = self._build_block_mask_direct() | |
| else: | |
| self.block_mask = self.build_block_mask() | |
| - | |
| + # self.physical_to_logical: torch.Size([512, 71999]) | |
| + global g_decode_offset, g_doc_ids, g_physical_to_logical | |
| + if g_decode_offset is None: | |
| + g_decode_offset = torch.empty(512, dtype=torch.int32, device="cuda") | |
| + g_doc_ids = torch.empty(512, dtype=torch.int32, device="cuda") | |
| + g_physical_to_logical = torch.empty(512, self.physical_to_logical.size(1), dtype=torch.int32, device="cuda") | |
| + | |
| + S = self.decode_offset.size(0) | |
| + assert S <= g_decode_offset.size(0) | |
| + g_decode_offset[:S].copy_(self.decode_offset) | |
| + self.decode_offset = g_decode_offset[:S] | |
| + | |
| + S = self.doc_ids.size(0) | |
| + assert S <= g_doc_ids.size(0) | |
| + g_doc_ids[:S].copy_(self.doc_ids) | |
| + self.doc_ids = g_doc_ids[:S] | |
| + | |
| + S1, S2= self.physical_to_logical.size() | |
| + assert S1 <= g_physical_to_logical.size(0) | |
| + assert S2 == g_physical_to_logical.size(1) | |
| + g_physical_to_logical[:S1, :].copy_(self.physical_to_logical) | |
| + self.physical_to_logical = g_physical_to_logical[:S1, :] | |
| + | |
| + # print(f"self.physical_to_logical shape {self.physical_to_logical.shape}, {self.physical_to_logical.data_ptr()}") | |
| class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadata]): | |
| + _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS | |
| + # _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE | |
| + | |
| def __init__( | |
| self, | |
| kv_cache_spec: AttentionSpec, | |
| @@ -759,6 +821,7 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat | |
| q_block_size=self.q_block_size, | |
| kv_block_size=self.kv_block_size, | |
| ) | |
| + # torch.distributed.breakpoint() # TODO | |
| return out | |
| def use_cascade_attention(self, *args, **kwargs) -> bool: | |
| @@ -942,6 +1005,7 @@ class FlexAttentionImpl(AttentionImpl): | |
| kernel_options = get_kernel_options( | |
| query, block_m, block_n, attn_metadata.direct_build | |
| ) | |
| + # print(f"{attn_metadata.transformed_score_mod=}, {attn_metadata.block_mask=}, {self.scale=}, {kernel_options=}") # TODO | |
| out = flex_attention_compiled( | |
| query, | |
| key_tensor, | |
| @@ -962,9 +1026,11 @@ class FlexAttentionImpl(AttentionImpl): | |
| def get_kernel_options( | |
| query, block_m, block_n, use_direct_build: bool | |
| ) -> dict[str, int | bool]: | |
| - kernel_options: dict[str, int | bool] = { | |
| - "FORCE_USE_FLEX_ATTENTION": True, | |
| - } | |
| + use_flex_decoding = os.getenv("USE_FLEX_DECODING") == "1" | |
| + kernel_options: dict[str, int | bool] = {} | |
| + | |
| + if not use_flex_decoding: | |
| + kernel_options["FORCE_USE_FLEX_ATTENTION"] = True | |
| def ensure_divisible(candidate: int, block_size: int) -> int: | |
| """Pick a kernel block size that divides the logical block.""" | |
| diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py | |
| index ca7be990c..2693ddd06 100644 | |
| --- a/vllm/v1/attention/backends/triton_attn.py | |
| +++ b/vllm/v1/attention/backends/triton_attn.py | |
| @@ -112,7 +112,7 @@ class TritonAttentionMetadata: | |
| class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]): | |
| - _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS | |
| + # _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS | |
| def __init__( | |
| self, | |
| @@ -248,6 +248,7 @@ class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMet | |
| softmax_segm_max=self.softmax_segm_max, | |
| softmax_segm_expsum=self.softmax_segm_expsum, | |
| ) | |
| + print(f"slot_mapping addr {slot_mapping.data_ptr()}") | |
| return attn_metadata | |
| diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py | |
| index 40c3e9a51..8d2e88f1c 100644 | |
| --- a/vllm/v1/engine/core.py | |
| +++ b/vllm/v1/engine/core.py | |
| @@ -112,8 +112,6 @@ class EngineCore: | |
| vllm_config | |
| ) | |
| - vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks | |
| - vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks | |
| self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks)) | |
| self.structured_output_manager = StructuredOutputManager(vllm_config) | |
| @@ -254,6 +252,9 @@ class EngineCore: | |
| num_gpu_blocks = scheduler_kv_cache_config.num_blocks | |
| num_cpu_blocks = 0 | |
| + vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks | |
| + vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks | |
| + | |
| # Initialize kv cache and warmup the execution | |
| self.model_executor.initialize_from_config(kv_cache_configs) | |
| diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py | |
| index 92822d829..bc18b9a26 100644 | |
| --- a/vllm/v1/worker/gpu_model_runner.py | |
| +++ b/vllm/v1/worker/gpu_model_runner.py | |
| @@ -19,6 +19,7 @@ import torch.distributed | |
| import torch.nn as nn | |
| from tqdm import tqdm | |
| +from vllm.attention.backends.registry import AttentionBackendEnum | |
| import vllm.envs as envs | |
| from vllm.attention.backends.abstract import ( | |
| AttentionBackend, | |
| @@ -1542,10 +1543,14 @@ class GPUModelRunner( | |
| if ubatch_slices is not None: | |
| attn_metadata = [dict() for _ in range(len(ubatch_slices))] | |
| - if for_cudagraph_capture: | |
| + if self.vllm_config.attention_config.backend != AttentionBackendEnum.FLEX_ATTENTION and for_cudagraph_capture: | |
| # For some attention backends (e.g. FA) with sliding window models we need | |
| # to make sure the backend see a max_seq_len that is larger to the sliding | |
| # window size when capturing to make sure the correct kernel is selected. | |
| + # | |
| + # For FlexAttention, warmup and capture with different max_seq_len | |
| + # will cause recompilation during graph capture. | |
| + # We need to avoid that. | |
| max_seq_len = self.max_model_len | |
| else: | |
| max_seq_len = self.seq_lens.np[:num_reqs].max().item() | |
| @@ -1757,6 +1762,7 @@ class GPUModelRunner( | |
| spec_decode_common_attn_metadata.unpadded(num_tokens, num_reqs) | |
| ) | |
| + # breakpoint() | |
| return attn_metadata, spec_decode_common_attn_metadata | |
| def _compute_cascade_attn_prefix_lens( |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment