Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save shunting314/0e804bc3bcf71c085351d606f0665876 to your computer and use it in GitHub Desktop.

Select an option

Save shunting314/0e804bc3bcf71c085351d606f0665876 to your computer and use it in GitHub Desktop.
diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py
index 0748643a5..bbab180ae 100644
--- a/vllm/compilation/cuda_graph.py
+++ b/vllm/compilation/cuda_graph.py
@@ -7,6 +7,7 @@ from collections.abc import Callable
from contextlib import ExitStack
from typing import Any
from unittest.mock import patch
+from vllm.forward_context import ForwardContext, get_forward_context
import torch
@@ -298,5 +299,12 @@ class CUDAGraphWrapper:
f"got {new_input_addresses}"
)
+ if cudagraph_runtime_mode == CUDAGraphMode.FULL:
+ from vllm.v1.attention.backends.flex_attention import g_kv_indices, g_kv_num_blocks
+ forward_context = get_forward_context()
+ attn_metadata = forward_context.attn_metadata
+ m0 = attn_metadata["model.layers.0.self_attn.attn"]
+ m1 = attn_metadata["model.layers.1.self_attn.attn"]
+ # breakpoint()
entry.cudagraph.replay()
return entry.output
diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py
index 8193c05c2..b29364ffc 100644
--- a/vllm/v1/attention/backends/flex_attention.py
+++ b/vllm/v1/attention/backends/flex_attention.py
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlexAttention."""
+import os
import math
from dataclasses import dataclass
from functools import cached_property
@@ -40,6 +41,17 @@ from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
)
from vllm.v1.kv_cache_interface import AttentionSpec
+from vllm.v1.attention.backends.utils import (
+ AttentionCGSupport,
+)
+
+g_kv_indices = None
+g_kv_num_blocks = None
+g_decode_offset = None
+g_doc_ids = None
+# self.seq_lens is fine
+# self.query_start_loc.shape is fine
+g_physical_to_logical = None
logger = init_logger(__name__)
@@ -617,6 +629,30 @@ class FlexAttentionMetadata:
).to(torch.int32)
kv_num_blocks = (kv_indices >= 0).sum(dim=-1).to(torch.int32)
+
+ global g_kv_indices, g_kv_num_blocks
+
+ if g_kv_indices is None:
+ M = 1024
+ N = 1024
+ g_kv_indices = torch.empty(M, N, dtype=torch.int32, device="cuda")
+ g_kv_num_blocks = torch.empty(M, dtype=torch.int32, device="cuda")
+
+ if True:
+ assert kv_indices.size(0) <= g_kv_indices.size(0)
+ # if g_kv_indices.shape != kv_indices.shape or g_kv_num_blocks.shape != kv_num_blocks.shape: breakpoint()
+ if kv_indices.size(1) > g_kv_indices.size(1):
+ breakpoint()
+ new_siz0, new_siz1 = kv_indices.size()
+ g_kv_indices.fill_(-1)
+ g_kv_num_blocks.zero_()
+ g_kv_indices[:new_siz0, :new_siz1].copy_(kv_indices)
+ g_kv_num_blocks[:new_siz0].copy_(kv_num_blocks)
+ kv_indices = g_kv_indices[:new_siz0, :new_siz1]
+ kv_num_blocks = g_kv_num_blocks[:new_siz0]
+
+ # print(f"{kv_indices.shape=} {kv_indices.data_ptr()} {kv_num_blocks.shape=}, {kv_num_blocks.data_ptr()}")
+
block_mask_kwargs = {
"seq_lengths": (self.num_actual_tokens, self.total_cache_tokens),
"kv_num_blocks": kv_num_blocks[None, None],
@@ -662,9 +698,35 @@ class FlexAttentionMetadata:
self.block_mask = self._build_block_mask_direct()
else:
self.block_mask = self.build_block_mask()
-
+ # self.physical_to_logical: torch.Size([512, 71999])
+ global g_decode_offset, g_doc_ids, g_physical_to_logical
+ if g_decode_offset is None:
+ g_decode_offset = torch.empty(512, dtype=torch.int32, device="cuda")
+ g_doc_ids = torch.empty(512, dtype=torch.int32, device="cuda")
+ g_physical_to_logical = torch.empty(512, self.physical_to_logical.size(1), dtype=torch.int32, device="cuda")
+
+ S = self.decode_offset.size(0)
+ assert S <= g_decode_offset.size(0)
+ g_decode_offset[:S].copy_(self.decode_offset)
+ self.decode_offset = g_decode_offset[:S]
+
+ S = self.doc_ids.size(0)
+ assert S <= g_doc_ids.size(0)
+ g_doc_ids[:S].copy_(self.doc_ids)
+ self.doc_ids = g_doc_ids[:S]
+
+ S1, S2= self.physical_to_logical.size()
+ assert S1 <= g_physical_to_logical.size(0)
+ assert S2 == g_physical_to_logical.size(1)
+ g_physical_to_logical[:S1, :].copy_(self.physical_to_logical)
+ self.physical_to_logical = g_physical_to_logical[:S1, :]
+
+ # print(f"self.physical_to_logical shape {self.physical_to_logical.shape}, {self.physical_to_logical.data_ptr()}")
class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadata]):
+ _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
+ # _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
+
def __init__(
self,
kv_cache_spec: AttentionSpec,
@@ -759,6 +821,7 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
q_block_size=self.q_block_size,
kv_block_size=self.kv_block_size,
)
+ # torch.distributed.breakpoint() # TODO
return out
def use_cascade_attention(self, *args, **kwargs) -> bool:
@@ -942,6 +1005,7 @@ class FlexAttentionImpl(AttentionImpl):
kernel_options = get_kernel_options(
query, block_m, block_n, attn_metadata.direct_build
)
+ # print(f"{attn_metadata.transformed_score_mod=}, {attn_metadata.block_mask=}, {self.scale=}, {kernel_options=}") # TODO
out = flex_attention_compiled(
query,
key_tensor,
@@ -962,9 +1026,11 @@ class FlexAttentionImpl(AttentionImpl):
def get_kernel_options(
query, block_m, block_n, use_direct_build: bool
) -> dict[str, int | bool]:
- kernel_options: dict[str, int | bool] = {
- "FORCE_USE_FLEX_ATTENTION": True,
- }
+ use_flex_decoding = os.getenv("USE_FLEX_DECODING") == "1"
+ kernel_options: dict[str, int | bool] = {}
+
+ if not use_flex_decoding:
+ kernel_options["FORCE_USE_FLEX_ATTENTION"] = True
def ensure_divisible(candidate: int, block_size: int) -> int:
"""Pick a kernel block size that divides the logical block."""
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index ca7be990c..2693ddd06 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -112,7 +112,7 @@ class TritonAttentionMetadata:
class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
- _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
+ # _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
def __init__(
self,
@@ -248,6 +248,7 @@ class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMet
softmax_segm_max=self.softmax_segm_max,
softmax_segm_expsum=self.softmax_segm_expsum,
)
+ print(f"slot_mapping addr {slot_mapping.data_ptr()}")
return attn_metadata
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 40c3e9a51..8d2e88f1c 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -112,8 +112,6 @@ class EngineCore:
vllm_config
)
- vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
- vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))
self.structured_output_manager = StructuredOutputManager(vllm_config)
@@ -254,6 +252,9 @@ class EngineCore:
num_gpu_blocks = scheduler_kv_cache_config.num_blocks
num_cpu_blocks = 0
+ vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
+ vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
+
# Initialize kv cache and warmup the execution
self.model_executor.initialize_from_config(kv_cache_configs)
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 92822d829..bc18b9a26 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -19,6 +19,7 @@ import torch.distributed
import torch.nn as nn
from tqdm import tqdm
+from vllm.attention.backends.registry import AttentionBackendEnum
import vllm.envs as envs
from vllm.attention.backends.abstract import (
AttentionBackend,
@@ -1542,10 +1543,14 @@ class GPUModelRunner(
if ubatch_slices is not None:
attn_metadata = [dict() for _ in range(len(ubatch_slices))]
- if for_cudagraph_capture:
+ if self.vllm_config.attention_config.backend != AttentionBackendEnum.FLEX_ATTENTION and for_cudagraph_capture:
# For some attention backends (e.g. FA) with sliding window models we need
# to make sure the backend see a max_seq_len that is larger to the sliding
# window size when capturing to make sure the correct kernel is selected.
+            #
+            # For FlexAttention, warming up and capturing with different
+            # max_seq_len values will cause recompilation to happen during
+            # graph capture. We need to avoid that.
max_seq_len = self.max_model_len
else:
max_seq_len = self.seq_lens.np[:num_reqs].max().item()
@@ -1757,6 +1762,7 @@ class GPUModelRunner(
spec_decode_common_attn_metadata.unpadded(num_tokens, num_reqs)
)
+ # breakpoint()
return attn_metadata, spec_decode_common_attn_metadata
def _compute_cascade_attn_prefix_lens(
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment