@shunting314
Created December 3, 2025 01:16
diff --git a/vllm/benchmarks/latency.py b/vllm/benchmarks/latency.py
index b4f175183..b539fd1f8 100644
--- a/vllm/benchmarks/latency.py
+++ b/vllm/benchmarks/latency.py
@@ -101,7 +101,7 @@ def main(args: argparse.Namespace):
sampling_params = SamplingParams(
n=args.n,
- temperature=1.0,
+ temperature=0.0,
top_p=1.0,
ignore_eos=True,
max_tokens=args.output_len,
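
Note on the hunk above: switching temperature from 1.0 to 0.0 turns the benchmark's sampling into greedy decoding, so repeated runs emit the same token stream, which is what makes the token-hash check added in llm.py further down meaningful. A minimal sketch of the resulting sampling configuration, with made-up stand-ins for the benchmark's CLI arguments:

from vllm import SamplingParams

# Hypothetical stand-ins for args.n / args.output_len.
n = 1
output_len = 128

sampling_params = SamplingParams(
    n=n,
    temperature=0.0,   # greedy decoding: argmax at every step, deterministic across runs
    top_p=1.0,
    ignore_eos=True,   # keep generating past EOS so every run produces output_len tokens
    max_tokens=output_len,
)
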
diff --git a/vllm/config/model.py b/vllm/config/model.py
index 49688e17c..18445845b 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -6,6 +6,7 @@ from collections.abc import Callable
from dataclasses import InitVar, field
from importlib.util import find_spec
from typing import TYPE_CHECKING, Any, Literal, cast, get_args
+import functools
import torch
from pydantic import ConfigDict, SkipValidation, field_validator, model_validator
@@ -314,6 +315,8 @@ class ModelConfig:
skip_mm_profiling: InitVar[bool | None] = None
video_pruning_rate: InitVar[float | None] = None
+ _is_encoder_decoder = None
+
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
@@ -1576,7 +1579,9 @@ class ModelConfig:
@property
def is_encoder_decoder(self) -> bool:
"""Extract the HF encoder/decoder model flag."""
- return is_encoder_decoder(self.hf_config)
+ if self._is_encoder_decoder is None:
+ self._is_encoder_decoder = is_encoder_decoder(self.hf_config)
+ return self._is_encoder_decoder
@property
def uses_alibi(self) -> bool:
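
The hunk above hand-rolls memoization for is_encoder_decoder with a sentinel class attribute; functools is imported but not actually used. A roughly equivalent sketch using functools.cached_property, assuming the surrounding class permits instance attribute assignment (illustration only, not the real vllm ModelConfig):

import functools

class HFConfigStub:
    is_encoder_decoder = False

class ModelConfigSketch:
    """Illustration only; shows the caching pattern, not the real config class."""

    def __init__(self, hf_config):
        self.hf_config = hf_config

    @functools.cached_property
    def is_encoder_decoder(self) -> bool:
        # Computed once on first access, then stored on the instance.
        return bool(getattr(self.hf_config, "is_encoder_decoder", False))

cfg = ModelConfigSketch(HFConfigStub())
assert cfg.is_encoder_decoder is False
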
diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py
index b6078706d..e5585800a 100644
--- a/vllm/config/scheduler.py
+++ b/vllm/config/scheduler.py
@@ -141,6 +141,8 @@ class SchedulerConfig:
while a larger value (e.g., 10) reduces host overhead and may increase throughput
by batching multiple tokens before sending."""
+ mega_step_size = 16
+
def get_scheduler_cls(self) -> type["SchedulerInterface"]:
if self.scheduler_cls is None:
if self.async_scheduling:
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 848916dbd..0e1c29bca 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -1718,6 +1718,8 @@ class LLM:
def _run_engine(
self, *, use_tqdm: bool | Callable[..., tqdm] = True
) -> list[RequestOutput | PoolingRequestOutput]:
+ import vllm.v1.core.sched.scheduler
+ vllm.v1.core.sched.scheduler.seen_mega_step = False
# Initialize tqdm.
if use_tqdm:
num_requests = self.llm_engine.get_num_unfinished_requests()
@@ -1764,4 +1766,15 @@ class LLM:
# Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than
# its previous requests.
+ if True:
+ # compute token hash
+ myhash = 0
+ for output in outputs:
+ assert len(output.outputs) == 1
+ out = output.outputs[0]
+ for tok in out.token_ids:
+ myhash = (myhash * 23 + tok) % 1000000007
+ print(f"myhash is {myhash}")
+ # breakpoint()
+
return sorted(outputs, key=lambda x: int(x.request_id))
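
The value printed above is a simple polynomial rolling hash over every generated token id, modulo a large prime; it collapses a whole run into one number that can be compared between a baseline run and a mega-step run. A standalone version of the same computation (the nested-list shape mirrors one output sequence per request):

def token_hash(all_token_ids: list[list[int]], mod: int = 1_000_000_007) -> int:
    """Polynomial rolling hash over all generated tokens, matching the check in llm.py."""
    h = 0
    for token_ids in all_token_ids:
        for tok in token_ids:
            h = (h * 23 + tok) % mod
    return h

# Two runs that produced identical tokens hash to the same value.
assert token_hash([[1, 2, 3], [4, 5]]) == token_hash([[1, 2, 3], [4, 5]])
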
diff --git a/vllm/envs.py b/vllm/envs.py
index 9b1ed1fc6..0d171581e 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -1527,6 +1527,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_V2_MODEL_RUNNER": lambda: bool(
int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0"))
),
+ "VLLM_MEGA_STEP": lambda: bool(
+ int(os.getenv("VLLM_MEGA_STEP", "1"))
+ ),
}
# --8<-- [end:env-vars-definition]
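
VLLM_MEGA_STEP defaults to on and is parsed the same way as the other boolean environment variables: the string is converted to int, then to bool. A small sketch of that parsing and of toggling the flag off for an A/B comparison:

import os

# Mirrors the lambda in envs.py: unset or "1" enables mega-step, "0" disables it.
def mega_step_enabled() -> bool:
    return bool(int(os.getenv("VLLM_MEGA_STEP", "1")))

os.environ["VLLM_MEGA_STEP"] = "0"   # e.g. to collect a baseline token hash
assert mega_step_enabled() is False
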
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index e3f499216..71d975a0c 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -659,6 +659,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
+ follow_mega=False,
) -> FlashInferMetadata:
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
@@ -698,6 +699,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
seq_lens_np = seq_lens_cpu.numpy()
+ # XXX nop if follow_mega
num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size
use_cascade = common_prefix_len > 0
@@ -728,34 +730,42 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
shared_kv_last_page_len_cpu = None
# write self.paged_kv_indptr_cpu inplace (0-index is always 0)
- np.cumsum(
- num_blocks_np,
- dtype=np.int32,
- out=self.paged_kv_indptr_np[1 : num_reqs + 1],
- )
+ if not follow_mega:
+ np.cumsum(
+ num_blocks_np,
+ dtype=np.int32,
+ out=self.paged_kv_indptr_np[1 : num_reqs + 1],
+ )
+
# NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified
# after this line (e.g., for cuda graphs), we need to copy the data to
# self.paged_kv_indptr_buffer to avoid race condition.
- self.paged_kv_indptr_buffer[: num_reqs + 1] = self.paged_kv_indptr_cpu[
- : num_reqs + 1
- ]
+ if not follow_mega:
+ self.paged_kv_indptr_buffer[: num_reqs + 1] = self.paged_kv_indptr_cpu[
+ : num_reqs + 1
+ ]
paged_kv_indptr = self.paged_kv_indptr[: num_reqs + 1]
- paged_kv_indptr.copy_(
- self.paged_kv_indptr_buffer[: num_reqs + 1], non_blocking=True
- )
+ # XXX nop if follow mega
+ if not follow_mega:
+ paged_kv_indptr.copy_(
+ self.paged_kv_indptr_buffer[: num_reqs + 1], non_blocking=True
+ )
# write self.paged_kv_indices inplace
num_actual_pages = self.paged_kv_indptr_np[num_reqs]
paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
- _copy_page_indices_kernel[(num_reqs,)](
- paged_kv_indices,
- block_table_tensor,
- block_table_tensor.stride(0),
- paged_kv_indptr,
- BLOCK_SIZE=1024,
- )
+ # XXX nop if follow mega
+ if not follow_mega:
+ _copy_page_indices_kernel[(num_reqs,)](
+ paged_kv_indices,
+ block_table_tensor,
+ block_table_tensor.stride(0),
+ paged_kv_indptr,
+ BLOCK_SIZE=1024,
+ )
# write self.paged_kv_last_page_len_cpu inplace
+ # XXX plus 1 for follow mega
paged_kv_last_page_len_np = seq_lens_np % page_size
self.paged_kv_last_page_len_np[:num_reqs] = np.where(
(paged_kv_last_page_len_np == 0) & (seq_lens_np != 0),
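
The follow_mega branches above rely on a paging invariant: when every request advances by exactly one decode token and no request crosses a page boundary, the number of KV pages per request is unchanged, so paged_kv_indptr and paged_kv_indices can be reused as-is and only the last-page length grows by one. A small NumPy sketch of that arithmetic under the no-boundary-crossing assumption (sequence lengths are made up):

import numpy as np

page_size = 16
seq_lens = np.array([33, 40, 47], dtype=np.int32)

def paging(seq_lens_np):
    num_blocks = (seq_lens_np + page_size - 1) // page_size
    last_page_len = seq_lens_np % page_size
    last_page_len = np.where(
        (last_page_len == 0) & (seq_lens_np != 0), page_size, last_page_len
    )
    return num_blocks, last_page_len

blocks_before, last_before = paging(seq_lens)
blocks_after, last_after = paging(seq_lens + 1)   # one more decode token per request

# No request crosses a page boundary here, so the block layout is identical
# and only the last-page lengths shift by one.
assert np.array_equal(blocks_before, blocks_after)
assert np.array_equal(last_after, last_before + 1)
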
diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py
index 7902513dc..7d3fcb6e2 100644
--- a/vllm/v1/core/sched/output.py
+++ b/vllm/v1/core/sched/output.py
@@ -188,6 +188,8 @@ class SchedulerOutput:
# list of mm_hash strings associated with the encoder outputs to be
# freed from the encoder cache.
free_encoder_mm_hashes: list[str]
+ is_mega_step: bool
+ num_computed_tokens_offset: int
# Request IDs that are preempted in this step.
# Only used for v2 model runner.
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 4cb5348cb..49b534332 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -49,6 +49,7 @@ from vllm.v1.utils import record_function_or_nullcontext
logger = init_logger(__name__)
+seen_mega_step = False
class Scheduler(SchedulerInterface):
def __init__(
@@ -191,6 +192,8 @@ class Scheduler(SchedulerInterface):
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
def schedule(self) -> SchedulerOutput:
+ start_ts = time.time()
+ global seen_mega_step
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
# Each request just has the num_computed_tokens and
@@ -221,8 +224,11 @@ class Scheduler(SchedulerInterface):
# First, schedule the RUNNING requests.
req_index = 0
+ is_mega_step = True
+ num_computed_tokens_offset = -1
while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index]
+ # if req_index == 0: print(f"schedule, req0 {request.num_computed_tokens=}")
num_new_tokens = (
request.num_tokens_with_spec
@@ -279,6 +285,9 @@ class Scheduler(SchedulerInterface):
# Schedule newly needed KV blocks for the request.
with record_function_or_nullcontext("schedule: allocate_slots"):
while True:
+ if seen_mega_step and request.num_computed_tokens % self.vllm_config.cache_config.block_size != 0:
+ new_blocks = self.kv_cache_manager.empty_kv_cache_blocks
+ break
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens,
@@ -343,6 +352,10 @@ class Scheduler(SchedulerInterface):
break
# Schedule the request.
+ if request.num_computed_tokens % self.scheduler_config.mega_step_size != 0 or num_new_tokens != 1:
+ num_computed_tokens_offset = request.num_computed_tokens % self.scheduler_config.mega_step_size
+ is_mega_step = False
+
scheduled_running_reqs.append(request)
req_to_new_blocks[request.request_id] = new_blocks
num_scheduled_tokens[request.request_id] = num_new_tokens
@@ -592,6 +605,10 @@ class Scheduler(SchedulerInterface):
req_index += 1
self.running.append(request)
+
+ if request.num_computed_tokens % self.scheduler_config.mega_step_size != 0 or num_new_tokens != 1:
+ num_computed_tokens_offset = request.num_computed_tokens % self.scheduler_config.mega_step_size
+ is_mega_step = False
if self.log_stats:
request.record_event(
EngineCoreEventType.SCHEDULED, scheduled_timestamp
@@ -692,6 +709,9 @@ class Scheduler(SchedulerInterface):
self.prev_step_scheduled_req_ids.clear()
self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys())
+ if is_mega_step:
+ seen_mega_step = True
+
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
scheduled_cached_reqs=cached_reqs_data,
@@ -707,6 +727,8 @@ class Scheduler(SchedulerInterface):
# the previous and the current steps.
finished_req_ids=self.finished_req_ids,
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
+ is_mega_step=is_mega_step,
+ num_computed_tokens_offset=num_computed_tokens_offset,
)
# NOTE(Kuntai): this function is designed for multiple purposes:
@@ -728,6 +750,10 @@ class Scheduler(SchedulerInterface):
with record_function_or_nullcontext("schedule: update_after_schedule"):
self._update_after_schedule(scheduler_output)
+
+ end_ts = time.time()
+ elapsed_ms = (end_ts - start_ts) * 1000
+ # print(f"schedule {elapsed_ms:.3f} ms")
return scheduler_output
def _update_after_schedule(
@@ -746,6 +772,11 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
for req_id, num_scheduled_token in num_scheduled_tokens.items():
request = self.requests[req_id]
+ if False:
+ if envs.VLLM_MEGA_STEP and scheduler_output.is_mega_step:
+ request.num_computed_tokens += self.scheduler_config.mega_step_size
+ else:
+ request.num_computed_tokens += num_scheduled_token
request.num_computed_tokens += num_scheduled_token
# NOTE: _free_encoder_inputs relies on num_computed_tokens, which
@@ -1028,6 +1059,10 @@ class Scheduler(SchedulerInterface):
# in pipeline parallelism).
continue
+ if envs.VLLM_MEGA_STEP and scheduler_output.is_mega_step:
+ request.num_computed_tokens += self.scheduler_config.mega_step_size - 1
+
+
req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids = (
sampled_token_ids[req_index] if sampled_token_ids else []
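
The scheduler changes above only declare a step a mega step when every scheduled request is in pure decode (exactly one new token) and its num_computed_tokens sits on a mega_step_size boundary; once a mega step completes, update_from_output credits each request with the full mega_step_size tokens at once. A sketch of the detection predicate, with made-up request objects:

from dataclasses import dataclass

MEGA_STEP_SIZE = 16  # mirrors SchedulerConfig.mega_step_size in this patch

@dataclass
class ReqSketch:
    num_computed_tokens: int
    num_new_tokens: int

def is_mega_step(requests: list[ReqSketch]) -> bool:
    # A step qualifies only if every request decodes exactly one token and is
    # aligned to the mega-step boundary, so 16 decode steps can be fused.
    return all(
        r.num_new_tokens == 1 and r.num_computed_tokens % MEGA_STEP_SIZE == 0
        for r in requests
    )

assert is_mega_step([ReqSketch(32, 1), ReqSketch(64, 1)])
assert not is_mega_step([ReqSketch(33, 1)])   # mid mega-step
assert not is_mega_step([ReqSketch(32, 5)])   # still prefilling or spec decoding
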
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 8657a95b5..2f0cf147d 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -16,6 +16,7 @@ from typing import Any, TypeVar, cast
import msgspec
import zmq
+from vllm import envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.envs import enable_envs_cache
@@ -327,6 +328,12 @@ class EngineCore:
return callback
+ def mega_step(self, scheduler_output):
+ # run execute_model
+ from vllm.v1.worker.mega_step_runner import MegaStepRunner
+ runner = MegaStepRunner(self, scheduler_output)
+ return runner.run()
+
def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
"""Schedule, execute, and make output.
@@ -338,17 +345,22 @@ class EngineCore:
# or finished and not yet removed from the batch.
if not self.scheduler.has_requests():
return {}, False
+
scheduler_output = self.scheduler.schedule()
- future = self.model_executor.execute_model(scheduler_output, non_block=True)
- grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
- with self.log_error_detail(scheduler_output):
- model_output = future.result()
- if model_output is None:
- model_output = self.model_executor.sample_tokens(grammar_output)
+ if envs.VLLM_MEGA_STEP and scheduler_output.is_mega_step:
+ model_output = self.mega_step(scheduler_output)
+ else:
+ future = self.model_executor.execute_model(scheduler_output, non_block=True)
+ grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
+ with self.log_error_detail(scheduler_output):
+ model_output = future.result()
+ if model_output is None:
+ model_output = self.model_executor.sample_tokens(grammar_output)
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)
+ # print(f"EngineCore.step: {model_output.sampled_token_ids=}") # TODO
return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index e403cea87..605964f30 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -292,21 +292,22 @@ class LLMEngine:
engine_core_timestamp=outputs.timestamp,
iteration_stats=iteration_stats,
)
- self.output_processor.update_scheduler_stats(outputs.scheduler_stats)
+ # self.output_processor.update_scheduler_stats(outputs.scheduler_stats)
# 3) Abort any reqs that finished due to stop strings.
with record_function_or_nullcontext("llm_engine step: abort_requests"):
self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
# 4) Record stats
- with record_function_or_nullcontext("llm_engine step: record_stats"):
- if self.logger_manager is not None and outputs.scheduler_stats is not None:
- self.logger_manager.record(
- scheduler_stats=outputs.scheduler_stats,
- iteration_stats=iteration_stats,
- mm_cache_stats=self.processor.stat_mm_cache(),
- )
- self.do_log_stats_with_interval()
+ if False:
+ with record_function_or_nullcontext("llm_engine step: record_stats"):
+ if self.logger_manager is not None and outputs.scheduler_stats is not None:
+ self.logger_manager.record(
+ scheduler_stats=outputs.scheduler_stats,
+ iteration_stats=iteration_stats,
+ mm_cache_stats=self.processor.stat_mm_cache(),
+ )
+ self.do_log_stats_with_interval()
return processed_outputs.request_outputs
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index e786cd8bc..98bec2171 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -580,6 +580,9 @@ class GPUModelRunner(
self.execute_model_state: ExecuteModelState | None = None
self.kv_connector_output: KVConnectorOutput | None = None
+ self.saved_num_sampled_tokens = None
+ self.saved_logits_indices = None
+
def reset_mm_cache(self) -> None:
if self.mm_budget:
self.mm_budget.reset_cache()
@@ -684,6 +687,7 @@ class GPUModelRunner(
The SamplingMetadata is updated and copied to the GPU if there is a
new/resumed/paused/finished request in the batch.
"""
+ start_ts = time.time()
# Remove finished requests from the cached states.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
@@ -769,6 +773,8 @@ class GPUModelRunner(
# then use it to update actual num_computed_tokens of each request.
valid_sampled_token_count = self._get_valid_sampled_token_count()
+ from vllm.v1.core.sched.scheduler import seen_mega_step
+ follow_mega = seen_mega_step and not scheduler_output.is_mega_step
for i, req_id in enumerate(req_data.req_ids):
req_state = self.requests[req_id]
num_computed_tokens = req_data.num_computed_tokens[i]
@@ -776,6 +782,10 @@ class GPUModelRunner(
resumed_from_preemption = req_id in req_data.resumed_req_ids
num_output_tokens = req_data.num_output_tokens[i]
req_index = self.input_batch.req_id_to_index.get(req_id)
+ self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens
+
+ if follow_mega:
+ continue
# prev_num_draft_len is used in async scheduling mode with
# spec decode. it indicates if need to update num_computed_tokens
@@ -859,7 +869,6 @@ class GPUModelRunner(
continue
# Update the persistent batch.
- self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens
if new_block_ids is not None:
self.input_batch.block_table.append_row(new_block_ids, req_index)
@@ -914,13 +923,18 @@ class GPUModelRunner(
for request in reqs_to_add:
self.input_batch.add_request(request)
- # Condense the batched states if there are gaps left by removed requests
- self.input_batch.condense()
- # Allow attention backend to reorder the batch, potentially
- self._may_reorder_batch(scheduler_output)
+ if not follow_mega:
+ # Condense the batched states if there are gaps left by removed requests
+ self.input_batch.condense()
+ # Allow attention backend to reorder the batch, potentially
+ self._may_reorder_batch(scheduler_output)
# Refresh batch metadata with any pending updates.
self.input_batch.refresh_metadata()
+ end_ts = time.time()
+ elapsed_ms = (end_ts - start_ts) * 1000
+ # print(f"_update_states elapsed {elapsed_ms:.3f} ms")
+
def _update_states_after_model_execute(
self, output_token_ids: torch.Tensor
) -> None:
@@ -1188,6 +1202,8 @@ class GPUModelRunner(
ubatch_slices, num_tokens_across_dp,
]
"""
+ start_ts = time.time()
+ from vllm.v1.core.sched.scheduler import seen_mega_step
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
@@ -1195,23 +1211,31 @@ class GPUModelRunner(
# OPTIMIZATION: Start copying the block table first.
# This way, we can overlap the copy with the following CPU operations.
- self.input_batch.block_table.commit_block_table(num_reqs)
-
- # Get request indices.
- # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
- req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens)
-
- # cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
- # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
- cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens)
-
- # Get positions.
- positions_np = self.positions.np[:total_num_scheduled_tokens]
- np.add(
- self.input_batch.num_computed_tokens_cpu[req_indices],
- arange,
- out=positions_np,
- )
+ follow_mega = seen_mega_step and not scheduler_output.is_mega_step
+ if follow_mega:
+ # assert torch.equal(self.input_batch.block_table[0].block_table.cpu[:num_reqs].cuda(), self.input_batch.block_table[0].block_table.gpu[:num_reqs])
+ # print("Can skip commiting block table")
+ pass
+ else:
+ self.input_batch.block_table.commit_block_table(num_reqs)
+
+ if not follow_mega:
+ # Get request indices.
+ # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
+ req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens)
+
+ # cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
+ # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
+ cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens)
+
+ # Get positions.
+ # XXX following step can just use base version + offset
+ positions_np = self.positions.np[:total_num_scheduled_tokens]
+ np.add(
+ self.input_batch.num_computed_tokens_cpu[req_indices],
+ arange,
+ out=positions_np,
+ )
# Calculate M-RoPE positions.
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
@@ -1222,20 +1246,22 @@ class GPUModelRunner(
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
# where M is the max_model_len.
- token_indices = (
- positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1]
- )
- token_indices_tensor = torch.from_numpy(token_indices)
-
- # NOTE(woosuk): We use torch.index_select instead of np.take here
- # because torch.index_select is much faster than np.take for large
- # tensors.
- torch.index_select(
- self.input_batch.token_ids_cpu_tensor.flatten(),
- 0,
- token_indices_tensor,
- out=self.input_ids.cpu[:total_num_scheduled_tokens],
- )
+ if not follow_mega:
+ token_indices = (
+ positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1]
+ )
+ token_indices_tensor = torch.from_numpy(token_indices)
+
+ # NOTE(woosuk): We use torch.index_select instead of np.take here
+ # because torch.index_select is much faster than np.take for large
+ # tensors.
+ torch.index_select(
+ self.input_batch.token_ids_cpu_tensor.flatten(),
+ 0,
+ token_indices_tensor,
+ out=self.input_ids.cpu[:total_num_scheduled_tokens],
+ )
+
if self.enable_prompt_embeds:
is_token_ids = self.input_batch.is_token_ids_tensor.flatten()
torch.index_select(
@@ -1283,23 +1309,34 @@ class GPUModelRunner(
output_idx += num_sched
- self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np)
- self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens)
+ # XXX following step can just use base version + offset
+ if not follow_mega:
+ self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np)
+ self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens)
+ else:
+ # dif = self.input_batch.block_table[0].slot_mapping.cpu[:total_num_scheduled_tokens].cuda() - self.input_batch.block_table[0].slot_mapping.gpu[:total_num_scheduled_tokens]
+ # assert torch.all(dif == scheduler_output.num_computed_tokens_offset).item()
+ self.input_batch.block_table[0].slot_mapping.gpu[:total_num_scheduled_tokens] += 1
+
+ if not follow_mega:
+ # Prepare the attention metadata.
+ self.query_start_loc.np[0] = 0
+ self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens
+ # Note: pad query_start_loc to be non-decreasing, as kernels
+ # like FlashAttention requires that
+ self.query_start_loc.np[num_reqs + 1 :].fill(cu_num_tokens[-1])
+ self.query_start_loc.copy_to_gpu()
- # Prepare the attention metadata.
- self.query_start_loc.np[0] = 0
- self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens
- # Note: pad query_start_loc to be non-decreasing, as kernels
- # like FlashAttention requires that
- self.query_start_loc.np[num_reqs + 1 :].fill(cu_num_tokens[-1])
- self.query_start_loc.copy_to_gpu()
query_start_loc = self.query_start_loc.gpu[: num_reqs + 1]
num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
num_tokens_padded = self._get_num_input_tokens(num_tokens_unpadded)
- uniform_decode = (
- max_num_scheduled_tokens == self.uniform_decode_query_len
- ) and (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
+ if not follow_mega:
+ uniform_decode = (
+ max_num_scheduled_tokens == self.uniform_decode_query_len
+ ) and (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
+ else:
+ uniform_decode = True
# Disable DP padding when running eager to avoid excessive padding when
# running prefills. This lets us set enforce_eager on the prefiller in
@@ -1317,33 +1354,41 @@ class GPUModelRunner(
num_scheduled_tokens_per_request=num_scheduled_tokens,
)
- self.seq_lens.np[:num_reqs] = (
- self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens
- )
- # Fill unused with 0 for full cuda graph mode.
- self.seq_lens.np[num_reqs:].fill(0)
- self.seq_lens.copy_to_gpu()
+ if follow_mega:
+ self.seq_lens.gpu[:num_reqs] += 1
+ else:
+ self.seq_lens.np[:num_reqs] = (
+ self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens
+ )
+ # Fill unused with 0 for full cuda graph mode.
+ self.seq_lens.np[num_reqs:].fill(0)
+ self.seq_lens.copy_to_gpu()
- num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids]
- num_tokens_np = np.array(num_tokens, dtype=np.int32)
# Record the index of requests that should not be sampled,
# so that we could clear the sampled tokens before returning
- discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np
- discard_request_indices = np.nonzero(discard_requests_mask)[0]
- self.num_discarded_requests = len(discard_request_indices)
- self.discard_request_indices.np[: self.num_discarded_requests] = (
- discard_request_indices
- )
-
- self.discard_request_indices.copy_to_gpu(self.num_discarded_requests)
+ if not follow_mega:
+ num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids]
+ num_tokens_np = np.array(num_tokens, dtype=np.int32)
+ discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np
+ discard_request_indices = np.nonzero(discard_requests_mask)[0]
+ self.num_discarded_requests = len(discard_request_indices)
+ self.discard_request_indices.np[: self.num_discarded_requests] = (
+ discard_request_indices
+ )
+
+ self.discard_request_indices.copy_to_gpu(self.num_discarded_requests)
# Copy the tensors to the GPU.
- self._prepare_input_ids(
- scheduler_output,
- total_num_scheduled_tokens,
- cu_num_tokens,
- )
+ if False:
+ self._prepare_input_ids(
+ scheduler_output,
+ total_num_scheduled_tokens,
+ cu_num_tokens,
+ )
+
+ if not follow_mega:
+ self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
if self.uses_mrope:
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
@@ -1353,7 +1398,10 @@ class GPUModelRunner(
)
else:
# Common case (1D positions)
- self.positions.copy_to_gpu(total_num_scheduled_tokens)
+ if follow_mega:
+ self.positions.gpu[:total_num_scheduled_tokens] += 1
+ else:
+ self.positions.copy_to_gpu(total_num_scheduled_tokens)
use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
if not use_spec_decode:
@@ -1362,10 +1410,14 @@ class GPUModelRunner(
# from these partial requests, we do so for simplicity.
# We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs.
- logits_indices = query_start_loc[1:] - 1
+ if not follow_mega:
+ self.saved_logits_indices = logits_indices = query_start_loc[1:] - 1
+ self.saved_num_sampled_tokens = num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
+ else:
+ logits_indices = self.saved_logits_indices
+ num_sampled_tokens = self.saved_num_sampled_tokens
num_draft_tokens = None
spec_decode_metadata = None
- num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
else:
# Get the number of draft tokens for each request.
# Iterate over the dictionary rather than all requests since not all
@@ -1408,6 +1460,9 @@ class GPUModelRunner(
self.input_batch, num_scheduled_tokens, num_sampled_tokens
)
+ end_ts = time.time()
+ elapsed_ms = (end_ts - start_ts) * 1000
+ # print(f"_prepare_inputs elapsed_ms {elapsed_ms:.3f} ms")
return (
logits_indices,
spec_decode_metadata,
@@ -1426,6 +1481,7 @@ class GPUModelRunner(
for_cudagraph_capture: bool = False,
scheduled_encoder_inputs: dict[str, list[int]] | None = None,
cascade_attn_prefix_lens: list[list[int]] | None = None,
+ follow_mega=False,
) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]:
"""
:return: tuple[attn_metadata, spec_decode_common_attn_metadata]
@@ -1591,6 +1647,7 @@ class GPUModelRunner(
attn_metadata_i = builder.build(
common_prefix_len=cascade_attn_prefix_len,
common_attn_metadata=common_attn_metadata,
+ follow_mega=follow_mega,
**extra_attn_metadata_args,
)
for layer_name in attn_group.layer_names:
@@ -2618,6 +2675,9 @@ class GPUModelRunner(
scheduler_output: "SchedulerOutput",
intermediate_tensors: IntermediateTensors | None = None,
) -> ModelRunnerOutput | IntermediateTensors | None:
+ # print(f"execute_model scheduler_output = {scheduler_output}")
+ from vllm.v1.core.sched.scheduler import seen_mega_step
+ follow_mega = seen_mega_step and not scheduler_output.is_mega_step
if self.execute_model_state is not None:
raise RuntimeError(
"State error: sample_tokens() must be called "
@@ -2718,6 +2778,7 @@ class GPUModelRunner(
use_spec_decode=use_spec_decode,
scheduled_encoder_inputs=scheduler_output.scheduled_encoder_inputs,
cascade_attn_prefix_lens=cascade_attn_prefix_lens,
+ follow_mega=follow_mega,
)
)
@@ -2789,6 +2850,7 @@ class GPUModelRunner(
inputs_embeds=inputs_embeds,
**model_kwargs,
)
+ # if positions[0].item() == 16: breakpoint()
with record_function_or_nullcontext("gpu_model_runner: postprocess"):
if self.use_aux_hidden_state_outputs:
@@ -2904,6 +2966,7 @@ class GPUModelRunner(
with record_function_or_nullcontext("gpu_model_runner: sample"):
sampler_output = self._sample(logits, spec_decode_metadata)
+ self.input_ids.gpu[:logits.shape[0]] = sampler_output.sampled_token_ids.flatten()
self.input_batch.prev_sampled_token_ids = None
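
The follow_mega path in the model runner skips the CPU-side bookkeeping entirely: positions, seq_lens, and the slot mapping already live on the GPU from the previous step, so advancing one decode token is just an in-place +1 on those buffers, and the sampled token ids written into input_ids at the end of _sample become the next step's inputs. A minimal torch sketch of that delta update on preallocated buffers (names are illustrative, not the runner's real attributes):

import torch

positions = torch.tensor([10, 20, 30, 40])        # one decode token per request
seq_lens = positions + 1
slot_mapping = torch.tensor([170, 330, 490, 650])

def apply_decode_delta():
    # Equivalent of the follow_mega branches: no CPU recompute, no host-to-device
    # copy, just advance every per-request counter by one token in place.
    positions.add_(1)
    seq_lens.add_(1)
    slot_mapping.add_(1)

apply_decode_delta()
assert positions.tolist() == [11, 21, 31, 41]
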
diff --git a/vllm/v1/worker/mega_step_runner.py b/vllm/v1/worker/mega_step_runner.py
new file mode 100644
index 000000000..d1a602ee5
--- /dev/null
+++ b/vllm/v1/worker/mega_step_runner.py
@@ -0,0 +1,254 @@
+import numpy as np
+import torch
+from vllm.forward_context import BatchDescriptor, set_forward_context
+from vllm.v1.outputs import ModelRunnerOutput
+from dataclasses import dataclass
+
+# TODO delete
+@dataclass
+class MegaStepOutput:
+ sampled_token_ids_list: list[list[list[int]]] # one list for each mini step
+
+class MegaStepRunner:
+ def __init__(self, engine_core: "EngineCore", scheduler_output):
+ self.engine_core = engine_core
+ self.scheduler_output = scheduler_output
+ self.model_executor = self.engine_core.model_executor
+ self.scheduler = self.engine_core.scheduler
+ self.worker = self.model_executor.driver_worker.worker
+ self.model_runner = self.worker.model_runner
+
+ # fields copied from model_runner
+ self.input_batch = self.model_runner.input_batch
+ self.input_ids = self.model_runner.input_ids
+ self.query_start_loc = self.model_runner.query_start_loc
+ self.seq_lens = self.model_runner.seq_lens
+ self.vllm_config = self.model_runner.vllm_config
+ self.cudagraph_dispatcher = self.model_runner.cudagraph_dispatcher
+
+ self.num_input_tokens = self.scheduler_output.total_num_scheduled_tokens
+ self.num_reqs = self.input_batch.num_reqs
+ self.req_ids = self.input_batch.req_ids
+
+ self.num_scheduled_tokens = np.array(
+ [scheduler_output.num_scheduled_tokens[i] for i in self.req_ids],
+ dtype=np.int32,
+ )
+
+ # TODO: find a better way to accumulate results
+ self.model_output_list = []
+
+ def _update_states(self):
+ req_data = self.scheduler_output.scheduled_cached_reqs
+ for i, req_id in enumerate(req_data.req_ids):
+ req_index = self.input_batch.req_id_to_index.get(req_id)
+ num_computed_tokens = req_data.num_computed_tokens[i]
+ self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens
+
+ new_block_ids = req_data.new_block_ids[i]
+
+ if new_block_ids is not None:
+ self.input_batch.block_table.append_row(new_block_ids, req_index)
+
+ def _prepare_inputs(self, no_override_input):
+ self.input_batch.block_table.commit_block_table(self.num_reqs)
+
+ req_indices = np.repeat(self.model_runner.arange_np[:self.num_reqs], self.num_scheduled_tokens)
+ cu_num_tokens, arange = self.model_runner._get_cumsum_and_arange(self.num_scheduled_tokens)
+ positions_np = self.model_runner.positions.np[:self.num_input_tokens]
+ np.add(
+ self.input_batch.num_computed_tokens_cpu[req_indices],
+ arange,
+ out=positions_np,
+ )
+ self.model_runner.positions.copy_to_gpu(self.num_input_tokens)
+
+ self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np)
+ self.input_batch.block_table.commit_slot_mapping(self.scheduler_output.total_num_scheduled_tokens)
+
+ token_indices = (
+ positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1]
+ )
+ token_indices_tensor = torch.from_numpy(token_indices)
+
+ torch.index_select(
+ self.input_batch.token_ids_cpu_tensor.flatten(),
+ 0,
+ token_indices_tensor,
+ out=self.input_ids.cpu[:self.num_input_tokens],
+ )
+
+ if not no_override_input:
+ self.input_ids.copy_to_gpu(self.num_input_tokens)
+
+ self.query_start_loc.np[0] = 0
+ self.query_start_loc.np[1 : self.num_reqs + 1] = cu_num_tokens
+ self.query_start_loc.np[self.num_reqs + 1 :].fill(cu_num_tokens[-1])
+ self.query_start_loc.copy_to_gpu()
+
+ query_start_loc = self.query_start_loc.gpu[: self.num_reqs + 1]
+
+ self.seq_lens.np[:self.num_reqs] = self.input_batch.num_computed_tokens_cpu[:self.num_reqs] + self.num_scheduled_tokens
+ self.seq_lens.np[self.num_reqs:].fill(0)
+ self.seq_lens.copy_to_gpu()
+
+ self.logits_indices = query_start_loc[1:] - 1
+
+ # fields setup by the initial step
+ self.attn_metadata = None
+ self.cudagraph_runtime_mode = None
+ self.batch_descriptor = None
+ self.positions = None
+
+ def _dispatch_cudagraph(self):
+ batch_desc = BatchDescriptor(
+ num_tokens=self.num_input_tokens,
+ uniform_decode=True,
+ has_lora=False,
+ )
+ cudagraph_runtime_mode, batch_descriptor = (
+ self.cudagraph_dispatcher.dispatch(
+ batch_desc,
+ use_cascade_attn=False,
+ )
+ )
+
+ return cudagraph_runtime_mode, batch_descriptor
+
+ def _sample(self, hidden_states, logits):
+ sampler_output = self.model_runner._sample(logits, None)
+ self.input_ids.gpu[:logits.shape[0]] = sampler_output.sampled_token_ids.flatten()
+ (
+ num_nans_in_logits,
+ logprobs_lists,
+ valid_sampled_token_ids,
+ prompt_logprobs_dict,
+ req_ids_output_copy,
+ req_id_to_index_output_copy,
+ invalid_req_indices,
+ ) = self.model_runner._bookkeeping_sync(
+ self.scheduler_output,
+ sampler_output,
+ logits,
+ hidden_states,
+ self.scheduler_output.total_num_scheduled_tokens,
+ None,
+ )
+
+ return ModelRunnerOutput(
+ req_ids=req_ids_output_copy,
+ req_id_to_index=req_id_to_index_output_copy,
+ sampled_token_ids=valid_sampled_token_ids,
+ logprobs=logprobs_lists,
+ prompt_logprobs_dict={},
+ pooler_output=[],
+ kv_connector_output=None,
+ ec_connector_output=None,
+ num_nans_in_logits=num_nans_in_logits,
+ )
+
+ def initial_step(self, no_override_input=False):
+ self._update_states()
+ self._prepare_inputs(no_override_input)
+ input_ids = self.input_ids.gpu[:self.num_input_tokens]
+ self.positions = self.model_runner.positions.gpu[:self.num_input_tokens]
+
+ self.attn_metadata, _ = self.model_runner._build_attention_metadata(
+ total_num_scheduled_tokens=self.num_input_tokens,
+ max_num_scheduled_tokens=1,
+ num_reqs=self.num_reqs,
+ ubatch_slices=None,
+ logits_indices=self.logits_indices,
+ use_spec_decode=False,
+ scheduled_encoder_inputs=None,
+ cascade_attn_prefix_lens=0,
+ follow_mega=True,
+ )
+ self.cudagraph_runtime_mode, self.batch_descriptor = self._dispatch_cudagraph()
+
+ with (
+ set_forward_context(
+ self.attn_metadata,
+ self.vllm_config,
+ num_tokens=self.num_input_tokens,
+ num_tokens_across_dp=0,
+ cudagraph_runtime_mode=self.cudagraph_runtime_mode,
+ batch_descriptor=self.batch_descriptor,
+ ubatch_slices=None,
+ )
+ ):
+ # print(f"input_ids={input_ids}")
+ # print(f"positions={positions}")
+ model_output = self.model_runner._model_forward(
+ input_ids=input_ids,
+ positions=self.positions,
+ )
+ hidden_states = model_output
+ sample_hidden_states = hidden_states[self.logits_indices]
+ logits = self.model_runner.model.compute_logits(sample_hidden_states)
+
+ model_output: ModelRunnerOutput = self._sample(hidden_states, logits)
+
+ # breakpoint()
+ self.model_output_list.append(model_output)
+
+ def apply_delta(self):
+ """Apply delta for the following steps"""
+ if False:
+ req_data = self.scheduler_output.scheduled_cached_reqs
+ for i, req_id in enumerate(req_data.req_ids):
+ req_data.num_computed_tokens[i] += 1
+
+ self.positions += 1
+ # input ids is already setup in _sample
+ # update attn_metadata
+ # This assumes all value are the same object
+ meta_obj = next(iter(self.attn_metadata.values()))
+ meta_obj.slot_mapping += 1
+ meta_obj.max_seq_len += 1
+ meta_obj.seq_lens += 1
+
+ def following_step(self):
+ self.apply_delta()
+ with (
+ set_forward_context(
+ self.attn_metadata,
+ self.vllm_config,
+ num_tokens=self.num_input_tokens,
+ num_tokens_across_dp=0,
+ cudagraph_runtime_mode=self.cudagraph_runtime_mode,
+ batch_descriptor=self.batch_descriptor,
+ ubatch_slices=None,
+ )
+ ):
+ # print(f"input_ids={input_ids}")
+ # print(f"positions={positions}")
+ model_output = self.model_runner._model_forward(
+ input_ids=self.input_ids.gpu[:self.num_input_tokens],
+ positions=self.positions,
+ )
+ hidden_states = model_output
+ sample_hidden_states = hidden_states[self.logits_indices]
+ logits = self.model_runner.model.compute_logits(sample_hidden_states)
+
+ model_output: ModelRunnerOutput = self._sample(hidden_states, logits)
+
+ self.model_output_list.append(model_output)
+
+
+ def construct_output(self):
+ assert len(self.model_output_list) > 0
+ output = self.model_output_list[0]
+ sampled_token_ids = output.sampled_token_ids
+ for other in self.model_output_list[1:]:
+ other_token_ids = other.sampled_token_ids
+ for lhs, rhs in zip(sampled_token_ids, other_token_ids):
+ lhs.extend(rhs)
+ return output
+
+ def run(self):
+ self.initial_step()
+ for offset in range(1, self.scheduler.scheduler_config.mega_step_size):
+ self.following_step()
+
+ return self.construct_output()
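
End to end, MegaStepRunner runs one fully prepared forward pass and then mega_step_size - 1 cheap follow-up passes, each of which bumps the cached metadata by one token and reuses the previous step's sampled ids as input; construct_output then concatenates the per-step token ids per request. A compressed sketch of that control flow with the runner calls stubbed out (initial_step, apply_delta, and following_step here are stand-ins, not the real methods):

MEGA_STEP_SIZE = 16  # mirrors SchedulerConfig.mega_step_size

def run_mega_step(initial_step, apply_delta, following_step):
    """Stand-in for MegaStepRunner.run(): one full step, then delta-only steps."""
    outputs = [initial_step()]              # full input prep + forward + sample
    for _ in range(1, MEGA_STEP_SIZE):
        apply_delta()                       # positions / seq_lens / slot_mapping += 1
        outputs.append(following_step())    # forward + sample reusing cached metadata
    # construct_output(): concatenate sampled token ids per request across steps.
    merged = [list(ids) for ids in outputs[0]]
    for step_out in outputs[1:]:
        for acc, ids in zip(merged, step_out):
            acc.extend(ids)
    return merged

# Toy usage: two requests, each step sampling one token id per request.
tokens = iter(range(100))
result = run_mega_step(
    initial_step=lambda: [[next(tokens)], [next(tokens)]],
    apply_delta=lambda: None,
    following_step=lambda: [[next(tokens)], [next(tokens)]],
)
assert len(result) == 2 and all(len(r) == MEGA_STEP_SIZE for r in result)
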