Created December 3, 2025 01:16
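The patch below prototypes a "mega step" decode path for vLLM v1: once every running request is decoding exactly one token and sits on a `mega_step_size` (16-token) boundary, the scheduler flags a single step that the engine expands into 16 back-to-back forward passes. The follow-up passes skip the per-step CPU work (scheduling, block-table commits, input rebuilding) and instead bump `positions`, `slot_mapping`, and `seq_lens` in place on the GPU, feeding each pass's sampled tokens directly into the next pass's `input_ids`. The behavior is gated by the new `VLLM_MEGA_STEP` environment variable, and a token hash is added to `LLM._run_engine` so the mega-step path can be compared against the baseline.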
| diff --git a/vllm/benchmarks/latency.py b/vllm/benchmarks/latency.py | |
| index b4f175183..b539fd1f8 100644 | |
| --- a/vllm/benchmarks/latency.py | |
| +++ b/vllm/benchmarks/latency.py | |
| @@ -101,7 +101,7 @@ def main(args: argparse.Namespace): | |
| sampling_params = SamplingParams( | |
| n=args.n, | |
| - temperature=1.0, | |
| + temperature=0.0, | |
| top_p=1.0, | |
| ignore_eos=True, | |
| max_tokens=args.output_len, | |
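The latency benchmark switches to greedy decoding so repeated runs produce identical token streams, which makes the output hash added in `llm.py` below a meaningful correctness check. A minimal sketch of equivalent sampling parameters (the `max_tokens` value here is illustrative; the benchmark uses `args.output_len`):

```python
from vllm import SamplingParams

# Greedy, EOS-ignoring sampling matching the benchmark change: two runs of
# the same prompts should now yield identical tokens.
greedy = SamplingParams(n=1, temperature=0.0, top_p=1.0,
                        ignore_eos=True, max_tokens=128)
```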
| diff --git a/vllm/config/model.py b/vllm/config/model.py | |
| index 49688e17c..18445845b 100644 | |
| --- a/vllm/config/model.py | |
| +++ b/vllm/config/model.py | |
| @@ -6,6 +6,7 @@ from collections.abc import Callable | |
| from dataclasses import InitVar, field | |
| from importlib.util import find_spec | |
| from typing import TYPE_CHECKING, Any, Literal, cast, get_args | |
| +import functools | |
| import torch | |
| from pydantic import ConfigDict, SkipValidation, field_validator, model_validator | |
| @@ -314,6 +315,8 @@ class ModelConfig: | |
| skip_mm_profiling: InitVar[bool | None] = None | |
| video_pruning_rate: InitVar[float | None] = None | |
| + _is_encoder_decoder = None | |
| + | |
| def compute_hash(self) -> str: | |
| """ | |
| WARNING: Whenever a new field is added to this config, | |
| @@ -1576,7 +1579,9 @@ class ModelConfig: | |
| @property | |
| def is_encoder_decoder(self) -> bool: | |
| """Extract the HF encoder/decoder model flag.""" | |
| - return is_encoder_decoder(self.hf_config) | |
| + if self._is_encoder_decoder is None: | |
| + self._is_encoder_decoder = is_encoder_decoder(self.hf_config) | |
| + return self._is_encoder_decoder | |
| @property | |
| def uses_alibi(self) -> bool: | |
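`is_encoder_decoder` is now memoized in a `_is_encoder_decoder` field so hot-path config lookups stop re-inspecting the HF config; the newly imported `functools` would support the same pattern via `cached_property`. A standalone sketch of that alternative (the encoder/decoder check is simplified here and only illustrative):

```python
import functools

class ModelConfigSketch:
    def __init__(self, hf_config):
        self.hf_config = hf_config

    @functools.cached_property
    def is_encoder_decoder(self) -> bool:
        # Computed once on first access, then served from the instance dict.
        return bool(getattr(self.hf_config, "is_encoder_decoder", False))
```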
| diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py | |
| index b6078706d..e5585800a 100644 | |
| --- a/vllm/config/scheduler.py | |
| +++ b/vllm/config/scheduler.py | |
| @@ -141,6 +141,8 @@ class SchedulerConfig: | |
| while a larger value (e.g., 10) reduces host overhead and may increase throughput | |
| by batching multiple tokens before sending.""" | |
| + mega_step_size = 16 | |
| + | |
| def get_scheduler_cls(self) -> type["SchedulerInterface"]: | |
| if self.scheduler_cls is None: | |
| if self.async_scheduling: | |
| diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py | |
| index 848916dbd..0e1c29bca 100644 | |
| --- a/vllm/entrypoints/llm.py | |
| +++ b/vllm/entrypoints/llm.py | |
| @@ -1718,6 +1718,8 @@ class LLM: | |
| def _run_engine( | |
| self, *, use_tqdm: bool | Callable[..., tqdm] = True | |
| ) -> list[RequestOutput | PoolingRequestOutput]: | |
| + import vllm.v1.core.sched.scheduler | |
| + vllm.v1.core.sched.scheduler.seen_mega_step = False | |
| # Initialize tqdm. | |
| if use_tqdm: | |
| num_requests = self.llm_engine.get_num_unfinished_requests() | |
| @@ -1764,4 +1766,15 @@ class LLM: | |
| # Sort the outputs by request ID. | |
| # This is necessary because some requests may be finished earlier than | |
| # its previous requests. | |
| + if True: | |
| + # compute token hash | |
| + myhash = 0 | |
| + for output in outputs: | |
| + assert len(output.outputs) == 1 | |
| + out = output.outputs[0] | |
| + for tok in out.token_ids: | |
| + myhash = (myhash * 23 + tok) % 1000000007 | |
| + print(f"myhash is {myhash}") | |
| + # breakpoint() | |
| + | |
| return sorted(outputs, key=lambda x: int(x.request_id)) | |
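`_run_engine` now prints a rolling hash over all generated token ids; combined with greedy sampling above, this gives a cheap end-to-end determinism check when comparing the mega-step path against the baseline. A standalone sketch of the same hash (multiplier and modulus match the patch):

```python
MOD = 1_000_000_007

def token_hash(per_request_token_ids: list[list[int]]) -> int:
    """Polynomial rolling hash over all generated token ids, mod 1e9+7."""
    h = 0
    for token_ids in per_request_token_ids:
        for tok in token_ids:
            h = (h * 23 + tok) % MOD
    return h

# Identical outputs from two runs hash to the same value.
assert token_hash([[1, 2], [3]]) == token_hash([[1, 2], [3]])
```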
| diff --git a/vllm/envs.py b/vllm/envs.py | |
| index 9b1ed1fc6..0d171581e 100755 | |
| --- a/vllm/envs.py | |
| +++ b/vllm/envs.py | |
| @@ -1527,6 +1527,9 @@ environment_variables: dict[str, Callable[[], Any]] = { | |
| "VLLM_USE_V2_MODEL_RUNNER": lambda: bool( | |
| int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0")) | |
| ), | |
| + "VLLM_MEGA_STEP": lambda: bool( | |
| + int(os.getenv("VLLM_MEGA_STEP", "1")) | |
| + ), | |
| } | |
| # --8<-- [end:env-vars-definition] | |
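`VLLM_MEGA_STEP` is registered as an integer-parsed boolean and defaults to enabled; setting it to `0` restores the regular step path. How the flag parses, mirroring the lambda added to `environment_variables`:

```python
import os

# Unset or any non-zero integer enables mega steps; "0" disables them.
mega_step_enabled = bool(int(os.getenv("VLLM_MEGA_STEP", "1")))
```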
| diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py | |
| index e3f499216..71d975a0c 100755 | |
| --- a/vllm/v1/attention/backends/flashinfer.py | |
| +++ b/vllm/v1/attention/backends/flashinfer.py | |
| @@ -659,6 +659,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): | |
| common_prefix_len: int, | |
| common_attn_metadata: CommonAttentionMetadata, | |
| fast_build: bool = False, | |
| + follow_mega=False, | |
| ) -> FlashInferMetadata: | |
| num_reqs = common_attn_metadata.num_reqs | |
| num_actual_tokens = common_attn_metadata.num_actual_tokens | |
| @@ -698,6 +699,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): | |
| ) | |
| seq_lens_np = seq_lens_cpu.numpy() | |
| + # XXX nop if follow_mega | |
| num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size | |
| use_cascade = common_prefix_len > 0 | |
| @@ -728,34 +730,42 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): | |
| shared_kv_last_page_len_cpu = None | |
| # write self.paged_kv_indptr_cpu inplace (0-index is always 0) | |
| - np.cumsum( | |
| - num_blocks_np, | |
| - dtype=np.int32, | |
| - out=self.paged_kv_indptr_np[1 : num_reqs + 1], | |
| - ) | |
| + if not follow_mega: | |
| + np.cumsum( | |
| + num_blocks_np, | |
| + dtype=np.int32, | |
| + out=self.paged_kv_indptr_np[1 : num_reqs + 1], | |
| + ) | |
| + | |
| # NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified | |
| # after this line (e.g., for cuda graphs), we need to copy the data to | |
| # self.paged_kv_indptr_buffer to avoid race condition. | |
| - self.paged_kv_indptr_buffer[: num_reqs + 1] = self.paged_kv_indptr_cpu[ | |
| - : num_reqs + 1 | |
| - ] | |
| + if not follow_mega: | |
| + self.paged_kv_indptr_buffer[: num_reqs + 1] = self.paged_kv_indptr_cpu[ | |
| + : num_reqs + 1 | |
| + ] | |
| paged_kv_indptr = self.paged_kv_indptr[: num_reqs + 1] | |
| - paged_kv_indptr.copy_( | |
| - self.paged_kv_indptr_buffer[: num_reqs + 1], non_blocking=True | |
| - ) | |
| + # XXX nop if follow mega | |
| + if not follow_mega: | |
| + paged_kv_indptr.copy_( | |
| + self.paged_kv_indptr_buffer[: num_reqs + 1], non_blocking=True | |
| + ) | |
| # write self.paged_kv_indices inplace | |
| num_actual_pages = self.paged_kv_indptr_np[num_reqs] | |
| paged_kv_indices = self.paged_kv_indices[:num_actual_pages] | |
| - _copy_page_indices_kernel[(num_reqs,)]( | |
| - paged_kv_indices, | |
| - block_table_tensor, | |
| - block_table_tensor.stride(0), | |
| - paged_kv_indptr, | |
| - BLOCK_SIZE=1024, | |
| - ) | |
| + # XXX nop if follow mega | |
| + if not follow_mega: | |
| + _copy_page_indices_kernel[(num_reqs,)]( | |
| + paged_kv_indices, | |
| + block_table_tensor, | |
| + block_table_tensor.stride(0), | |
| + paged_kv_indptr, | |
| + BLOCK_SIZE=1024, | |
| + ) | |
| # write self.paged_kv_last_page_len_cpu inplace | |
| + # XXX plus 1 for follow mega | |
| paged_kv_last_page_len_np = seq_lens_np % page_size | |
| self.paged_kv_last_page_len_np[:num_reqs] = np.where( | |
| (paged_kv_last_page_len_np == 0) & (seq_lens_np != 0), | |
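For `follow_mega` steps the FlashInfer metadata builder skips recomputing and copying the paged-KV `indptr`/`indices`, since a one-token decode reuses the previous page layout; the `# XXX plus 1 for follow mega` note flags the last-page length as the remaining value that would still need a per-step bump on the delta path until a page boundary is crossed. A sketch of the baseline last-page-length rule the hunk leaves in place:

```python
import numpy as np

def last_page_len(seq_lens: np.ndarray, page_size: int) -> np.ndarray:
    # A sequence that exactly fills its last page reports page_size, not 0.
    rem = seq_lens % page_size
    return np.where((rem == 0) & (seq_lens != 0), page_size, rem).astype(np.int32)

assert last_page_len(np.array([16, 17, 31, 32]), 16).tolist() == [16, 1, 15, 16]
```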
| diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py | |
| index 7902513dc..7d3fcb6e2 100644 | |
| --- a/vllm/v1/core/sched/output.py | |
| +++ b/vllm/v1/core/sched/output.py | |
| @@ -188,6 +188,8 @@ class SchedulerOutput: | |
| # list of mm_hash strings associated with the encoder outputs to be | |
| # freed from the encoder cache. | |
| free_encoder_mm_hashes: list[str] | |
| + is_mega_step: bool | |
| + num_computed_tokens_offset: int | |
| # Request IDs that are preempted in this step. | |
| # Only used for v2 model runner. | |
| diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py | |
| index 4cb5348cb..49b534332 100644 | |
| --- a/vllm/v1/core/sched/scheduler.py | |
| +++ b/vllm/v1/core/sched/scheduler.py | |
| @@ -49,6 +49,7 @@ from vllm.v1.utils import record_function_or_nullcontext | |
| logger = init_logger(__name__) | |
| +seen_mega_step = False | |
| class Scheduler(SchedulerInterface): | |
| def __init__( | |
| @@ -191,6 +192,8 @@ class Scheduler(SchedulerInterface): | |
| self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER | |
| def schedule(self) -> SchedulerOutput: | |
| + start_ts = time.time() | |
| + global seen_mega_step | |
| # NOTE(woosuk) on the scheduling algorithm: | |
| # There's no "decoding phase" nor "prefill phase" in the scheduler. | |
| # Each request just has the num_computed_tokens and | |
| @@ -221,8 +224,11 @@ class Scheduler(SchedulerInterface): | |
| # First, schedule the RUNNING requests. | |
| req_index = 0 | |
| + is_mega_step = True | |
| + num_computed_tokens_offset = -1 | |
| while req_index < len(self.running) and token_budget > 0: | |
| request = self.running[req_index] | |
| + # if req_index == 0: print(f"schedule, req0 {request.num_computed_tokens=}") | |
| num_new_tokens = ( | |
| request.num_tokens_with_spec | |
| @@ -279,6 +285,9 @@ class Scheduler(SchedulerInterface): | |
| # Schedule newly needed KV blocks for the request. | |
| with record_function_or_nullcontext("schedule: allocate_slots"): | |
| while True: | |
| + if seen_mega_step and request.num_computed_tokens % self.vllm_config.cache_config.block_size != 0: | |
| + new_blocks = self.kv_cache_manager.empty_kv_cache_blocks | |
| + break | |
| new_blocks = self.kv_cache_manager.allocate_slots( | |
| request, | |
| num_new_tokens, | |
| @@ -343,6 +352,10 @@ class Scheduler(SchedulerInterface): | |
| break | |
| # Schedule the request. | |
| + if request.num_computed_tokens % self.scheduler_config.mega_step_size != 0 or num_new_tokens != 1: | |
| + num_computed_tokens_offset = request.num_computed_tokens % self.scheduler_config.mega_step_size | |
| + is_mega_step = False | |
| + | |
| scheduled_running_reqs.append(request) | |
| req_to_new_blocks[request.request_id] = new_blocks | |
| num_scheduled_tokens[request.request_id] = num_new_tokens | |
| @@ -592,6 +605,10 @@ class Scheduler(SchedulerInterface): | |
| req_index += 1 | |
| self.running.append(request) | |
| + | |
| + if request.num_computed_tokens % self.scheduler_config.mega_step_size != 0 or num_new_tokens != 1: | |
| + num_computed_tokens_offset = request.num_computed_tokens % self.scheduler_config.mega_step_size | |
| + is_mega_step = False | |
| if self.log_stats: | |
| request.record_event( | |
| EngineCoreEventType.SCHEDULED, scheduled_timestamp | |
| @@ -692,6 +709,9 @@ class Scheduler(SchedulerInterface): | |
| self.prev_step_scheduled_req_ids.clear() | |
| self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys()) | |
| + if is_mega_step: | |
| + seen_mega_step = True | |
| + | |
| scheduler_output = SchedulerOutput( | |
| scheduled_new_reqs=new_reqs_data, | |
| scheduled_cached_reqs=cached_reqs_data, | |
| @@ -707,6 +727,8 @@ class Scheduler(SchedulerInterface): | |
| # the previous and the current steps. | |
| finished_req_ids=self.finished_req_ids, | |
| free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(), | |
| + is_mega_step=is_mega_step, | |
| + num_computed_tokens_offset=num_computed_tokens_offset, | |
| ) | |
| # NOTE(Kuntai): this function is designed for multiple purposes: | |
| @@ -728,6 +750,10 @@ class Scheduler(SchedulerInterface): | |
| with record_function_or_nullcontext("schedule: update_after_schedule"): | |
| self._update_after_schedule(scheduler_output) | |
| + | |
| + end_ts = time.time() | |
| + elapsed_ms = (end_ts - start_ts) * 1000 | |
| + # print(f"schedule {elapsed_ms:.3f} ms") | |
| return scheduler_output | |
| def _update_after_schedule( | |
| @@ -746,6 +772,11 @@ class Scheduler(SchedulerInterface): | |
| num_scheduled_tokens = scheduler_output.num_scheduled_tokens | |
| for req_id, num_scheduled_token in num_scheduled_tokens.items(): | |
| request = self.requests[req_id] | |
| + if False: | |
| + if envs.VLLM_MEGA_STEP and scheduler_output.is_mega_step: | |
| + request.num_computed_tokens += self.scheduler_config.mega_step_size | |
| + else: | |
| + request.num_computed_tokens += num_scheduled_token | |
| request.num_computed_tokens += num_scheduled_token | |
| # NOTE: _free_encoder_inputs relies on num_computed_tokens, which | |
| @@ -1028,6 +1059,10 @@ class Scheduler(SchedulerInterface): | |
| # in pipeline parallelism). | |
| continue | |
| + if envs.VLLM_MEGA_STEP and scheduler_output.is_mega_step: | |
| + request.num_computed_tokens += self.scheduler_config.mega_step_size - 1 | |
| + | |
| + | |
| req_index = model_runner_output.req_id_to_index[req_id] | |
| generated_token_ids = ( | |
| sampled_token_ids[req_index] if sampled_token_ids else [] | |
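A step qualifies as a mega step only when every scheduled request is decoding exactly one token and its `num_computed_tokens` is a multiple of `mega_step_size`; once one mega step has been seen, requests that are mid-block skip `allocate_slots`, and `update_from_output` advances each request by the extra `mega_step_size - 1` tokens the runner produced. A minimal sketch of the per-request eligibility test (`mega_step_size` defaults to 16 in the patched `SchedulerConfig`):

```python
def mega_step_eligible(num_computed_tokens: int, num_new_tokens: int,
                       mega_step_size: int = 16) -> bool:
    # Pure decode (one new token) on a mega-step boundary.
    return num_new_tokens == 1 and num_computed_tokens % mega_step_size == 0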
| diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py | |
| index 8657a95b5..2f0cf147d 100644 | |
| --- a/vllm/v1/engine/core.py | |
| +++ b/vllm/v1/engine/core.py | |
| @@ -16,6 +16,7 @@ from typing import Any, TypeVar, cast | |
| import msgspec | |
| import zmq | |
| +from vllm import envs | |
| from vllm.config import ParallelConfig, VllmConfig | |
| from vllm.distributed import stateless_destroy_torch_distributed_process_group | |
| from vllm.envs import enable_envs_cache | |
| @@ -327,6 +328,12 @@ class EngineCore: | |
| return callback | |
| + def mega_step(self, scheduler_output): | |
| + # run execute_model | |
| + from vllm.v1.worker.mega_step_runner import MegaStepRunner | |
| + runner = MegaStepRunner(self, scheduler_output) | |
| + return runner.run() | |
| + | |
| def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]: | |
| """Schedule, execute, and make output. | |
| @@ -338,17 +345,22 @@ class EngineCore: | |
| # or finished and not yet removed from the batch. | |
| if not self.scheduler.has_requests(): | |
| return {}, False | |
| + | |
| scheduler_output = self.scheduler.schedule() | |
| - future = self.model_executor.execute_model(scheduler_output, non_block=True) | |
| - grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output) | |
| - with self.log_error_detail(scheduler_output): | |
| - model_output = future.result() | |
| - if model_output is None: | |
| - model_output = self.model_executor.sample_tokens(grammar_output) | |
| + if envs.VLLM_MEGA_STEP and scheduler_output.is_mega_step: | |
| + model_output = self.mega_step(scheduler_output) | |
| + else: | |
| + future = self.model_executor.execute_model(scheduler_output, non_block=True) | |
| + grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output) | |
| + with self.log_error_detail(scheduler_output): | |
| + model_output = future.result() | |
| + if model_output is None: | |
| + model_output = self.model_executor.sample_tokens(grammar_output) | |
| engine_core_outputs = self.scheduler.update_from_output( | |
| scheduler_output, model_output | |
| ) | |
| + # print(f"EngineCore.step: {model_output.sampled_token_ids=}") # TODO | |
| return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0 | |
| diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py | |
| index e403cea87..605964f30 100644 | |
| --- a/vllm/v1/engine/llm_engine.py | |
| +++ b/vllm/v1/engine/llm_engine.py | |
| @@ -292,21 +292,22 @@ class LLMEngine: | |
| engine_core_timestamp=outputs.timestamp, | |
| iteration_stats=iteration_stats, | |
| ) | |
| - self.output_processor.update_scheduler_stats(outputs.scheduler_stats) | |
| + # self.output_processor.update_scheduler_stats(outputs.scheduler_stats) | |
| # 3) Abort any reqs that finished due to stop strings. | |
| with record_function_or_nullcontext("llm_engine step: abort_requests"): | |
| self.engine_core.abort_requests(processed_outputs.reqs_to_abort) | |
| # 4) Record stats | |
| - with record_function_or_nullcontext("llm_engine step: record_stats"): | |
| - if self.logger_manager is not None and outputs.scheduler_stats is not None: | |
| - self.logger_manager.record( | |
| - scheduler_stats=outputs.scheduler_stats, | |
| - iteration_stats=iteration_stats, | |
| - mm_cache_stats=self.processor.stat_mm_cache(), | |
| - ) | |
| - self.do_log_stats_with_interval() | |
| + if False: | |
| + with record_function_or_nullcontext("llm_engine step: record_stats"): | |
| + if self.logger_manager is not None and outputs.scheduler_stats is not None: | |
| + self.logger_manager.record( | |
| + scheduler_stats=outputs.scheduler_stats, | |
| + iteration_stats=iteration_stats, | |
| + mm_cache_stats=self.processor.stat_mm_cache(), | |
| + ) | |
| + self.do_log_stats_with_interval() | |
| return processed_outputs.request_outputs | |
| diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py | |
| index e786cd8bc..98bec2171 100644 | |
| --- a/vllm/v1/worker/gpu_model_runner.py | |
| +++ b/vllm/v1/worker/gpu_model_runner.py | |
| @@ -580,6 +580,9 @@ class GPUModelRunner( | |
| self.execute_model_state: ExecuteModelState | None = None | |
| self.kv_connector_output: KVConnectorOutput | None = None | |
| + self.saved_num_sampled_tokens = None | |
| + self.saved_logits_indices = None | |
| + | |
| def reset_mm_cache(self) -> None: | |
| if self.mm_budget: | |
| self.mm_budget.reset_cache() | |
| @@ -684,6 +687,7 @@ class GPUModelRunner( | |
| The SamplingMetadata is updated and copied to the GPU if there is a | |
| new/resumed/paused/finished request in the batch. | |
| """ | |
| + start_ts = time.time() | |
| # Remove finished requests from the cached states. | |
| for req_id in scheduler_output.finished_req_ids: | |
| self.requests.pop(req_id, None) | |
| @@ -769,6 +773,8 @@ class GPUModelRunner( | |
| # then use it to update actual num_computed_tokens of each request. | |
| valid_sampled_token_count = self._get_valid_sampled_token_count() | |
| + from vllm.v1.core.sched.scheduler import seen_mega_step | |
| + follow_mega = seen_mega_step and not scheduler_output.is_mega_step | |
| for i, req_id in enumerate(req_data.req_ids): | |
| req_state = self.requests[req_id] | |
| num_computed_tokens = req_data.num_computed_tokens[i] | |
| @@ -776,6 +782,10 @@ class GPUModelRunner( | |
| resumed_from_preemption = req_id in req_data.resumed_req_ids | |
| num_output_tokens = req_data.num_output_tokens[i] | |
| req_index = self.input_batch.req_id_to_index.get(req_id) | |
| + self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens | |
| + | |
| + if follow_mega: | |
| + continue | |
| # prev_num_draft_len is used in async scheduling mode with | |
| # spec decode. it indicates if need to update num_computed_tokens | |
| @@ -859,7 +869,6 @@ class GPUModelRunner( | |
| continue | |
| # Update the persistent batch. | |
| - self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens | |
| if new_block_ids is not None: | |
| self.input_batch.block_table.append_row(new_block_ids, req_index) | |
| @@ -914,13 +923,18 @@ class GPUModelRunner( | |
| for request in reqs_to_add: | |
| self.input_batch.add_request(request) | |
| - # Condense the batched states if there are gaps left by removed requests | |
| - self.input_batch.condense() | |
| - # Allow attention backend to reorder the batch, potentially | |
| - self._may_reorder_batch(scheduler_output) | |
| + if not follow_mega: | |
| + # Condense the batched states if there are gaps left by removed requests | |
| + self.input_batch.condense() | |
| + # Allow attention backend to reorder the batch, potentially | |
| + self._may_reorder_batch(scheduler_output) | |
| # Refresh batch metadata with any pending updates. | |
| self.input_batch.refresh_metadata() | |
| + end_ts = time.time() | |
| + elapsed_ms = (end_ts - start_ts) * 1000 | |
| + # print(f"_update_states elapsed {elapsed_ms:.3f} ms") | |
| + | |
| def _update_states_after_model_execute( | |
| self, output_token_ids: torch.Tensor | |
| ) -> None: | |
| @@ -1188,6 +1202,8 @@ class GPUModelRunner( | |
| ubatch_slices, num_tokens_across_dp, | |
| ] | |
| """ | |
| + start_ts = time.time() | |
| + from vllm.v1.core.sched.scheduler import seen_mega_step | |
| total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens | |
| assert total_num_scheduled_tokens > 0 | |
| num_reqs = self.input_batch.num_reqs | |
| @@ -1195,23 +1211,31 @@ class GPUModelRunner( | |
| # OPTIMIZATION: Start copying the block table first. | |
| # This way, we can overlap the copy with the following CPU operations. | |
| - self.input_batch.block_table.commit_block_table(num_reqs) | |
| - | |
| - # Get request indices. | |
| - # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] | |
| - req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) | |
| - | |
| - # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] | |
| - # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] | |
| - cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens) | |
| - | |
| - # Get positions. | |
| - positions_np = self.positions.np[:total_num_scheduled_tokens] | |
| - np.add( | |
| - self.input_batch.num_computed_tokens_cpu[req_indices], | |
| - arange, | |
| - out=positions_np, | |
| - ) | |
| + follow_mega = seen_mega_step and not scheduler_output.is_mega_step | |
| + if follow_mega: | |
| + # assert torch.equal(self.input_batch.block_table[0].block_table.cpu[:num_reqs].cuda(), self.input_batch.block_table[0].block_table.gpu[:num_reqs]) | |
| + # print("Can skip commiting block table") | |
| + pass | |
| + else: | |
| + self.input_batch.block_table.commit_block_table(num_reqs) | |
| + | |
| + if not follow_mega: | |
| + # Get request indices. | |
| + # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] | |
| + req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) | |
| + | |
| + # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] | |
| + # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] | |
| + cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens) | |
| + | |
| + # Get positions. | |
| + # XXX following step can just use base version + offset | |
| + positions_np = self.positions.np[:total_num_scheduled_tokens] | |
| + np.add( | |
| + self.input_batch.num_computed_tokens_cpu[req_indices], | |
| + arange, | |
| + out=positions_np, | |
| + ) | |
| # Calculate M-RoPE positions. | |
| # Only relevant for models using M-RoPE (e.g, Qwen2-VL) | |
| @@ -1222,20 +1246,22 @@ class GPUModelRunner( | |
| # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] | |
| # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] | |
| # where M is the max_model_len. | |
| - token_indices = ( | |
| - positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1] | |
| - ) | |
| - token_indices_tensor = torch.from_numpy(token_indices) | |
| - | |
| - # NOTE(woosuk): We use torch.index_select instead of np.take here | |
| - # because torch.index_select is much faster than np.take for large | |
| - # tensors. | |
| - torch.index_select( | |
| - self.input_batch.token_ids_cpu_tensor.flatten(), | |
| - 0, | |
| - token_indices_tensor, | |
| - out=self.input_ids.cpu[:total_num_scheduled_tokens], | |
| - ) | |
| + if not follow_mega: | |
| + token_indices = ( | |
| + positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1] | |
| + ) | |
| + token_indices_tensor = torch.from_numpy(token_indices) | |
| + | |
| + # NOTE(woosuk): We use torch.index_select instead of np.take here | |
| + # because torch.index_select is much faster than np.take for large | |
| + # tensors. | |
| + torch.index_select( | |
| + self.input_batch.token_ids_cpu_tensor.flatten(), | |
| + 0, | |
| + token_indices_tensor, | |
| + out=self.input_ids.cpu[:total_num_scheduled_tokens], | |
| + ) | |
| + | |
| if self.enable_prompt_embeds: | |
| is_token_ids = self.input_batch.is_token_ids_tensor.flatten() | |
| torch.index_select( | |
| @@ -1283,23 +1309,34 @@ class GPUModelRunner( | |
| output_idx += num_sched | |
| - self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np) | |
| - self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens) | |
| + # XXX following step can just use base version + offset | |
| + if not follow_mega: | |
| + self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np) | |
| + self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens) | |
| + else: | |
| + # dif = self.input_batch.block_table[0].slot_mapping.cpu[:total_num_scheduled_tokens].cuda() - self.input_batch.block_table[0].slot_mapping.gpu[:total_num_scheduled_tokens] | |
| + # assert torch.all(dif == scheduler_output.num_computed_tokens_offset).item() | |
| + self.input_batch.block_table[0].slot_mapping.gpu[:total_num_scheduled_tokens] += 1 | |
| + | |
| + if not follow_mega: | |
| + # Prepare the attention metadata. | |
| + self.query_start_loc.np[0] = 0 | |
| + self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens | |
| + # Note: pad query_start_loc to be non-decreasing, as kernels | |
| + # like FlashAttention requires that | |
| + self.query_start_loc.np[num_reqs + 1 :].fill(cu_num_tokens[-1]) | |
| + self.query_start_loc.copy_to_gpu() | |
| - # Prepare the attention metadata. | |
| - self.query_start_loc.np[0] = 0 | |
| - self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens | |
| - # Note: pad query_start_loc to be non-decreasing, as kernels | |
| - # like FlashAttention requires that | |
| - self.query_start_loc.np[num_reqs + 1 :].fill(cu_num_tokens[-1]) | |
| - self.query_start_loc.copy_to_gpu() | |
| query_start_loc = self.query_start_loc.gpu[: num_reqs + 1] | |
| num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens | |
| num_tokens_padded = self._get_num_input_tokens(num_tokens_unpadded) | |
| - uniform_decode = ( | |
| - max_num_scheduled_tokens == self.uniform_decode_query_len | |
| - ) and (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens) | |
| + if not follow_mega: | |
| + uniform_decode = ( | |
| + max_num_scheduled_tokens == self.uniform_decode_query_len | |
| + ) and (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens) | |
| + else: | |
| + uniform_decode = True | |
| # Disable DP padding when running eager to avoid excessive padding when | |
| # running prefills. This lets us set enforce_eager on the prefiller in | |
| @@ -1317,33 +1354,41 @@ class GPUModelRunner( | |
| num_scheduled_tokens_per_request=num_scheduled_tokens, | |
| ) | |
| - self.seq_lens.np[:num_reqs] = ( | |
| - self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens | |
| - ) | |
| - # Fill unused with 0 for full cuda graph mode. | |
| - self.seq_lens.np[num_reqs:].fill(0) | |
| - self.seq_lens.copy_to_gpu() | |
| + if follow_mega: | |
| + self.seq_lens.gpu[:num_reqs] += 1 | |
| + else: | |
| + self.seq_lens.np[:num_reqs] = ( | |
| + self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens | |
| + ) | |
| + # Fill unused with 0 for full cuda graph mode. | |
| + self.seq_lens.np[num_reqs:].fill(0) | |
| + self.seq_lens.copy_to_gpu() | |
| - num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids] | |
| - num_tokens_np = np.array(num_tokens, dtype=np.int32) | |
| # Record the index of requests that should not be sampled, | |
| # so that we could clear the sampled tokens before returning | |
| - discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np | |
| - discard_request_indices = np.nonzero(discard_requests_mask)[0] | |
| - self.num_discarded_requests = len(discard_request_indices) | |
| - self.discard_request_indices.np[: self.num_discarded_requests] = ( | |
| - discard_request_indices | |
| - ) | |
| - | |
| - self.discard_request_indices.copy_to_gpu(self.num_discarded_requests) | |
| + if not follow_mega: | |
| + num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids] | |
| + num_tokens_np = np.array(num_tokens, dtype=np.int32) | |
| + discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np | |
| + discard_request_indices = np.nonzero(discard_requests_mask)[0] | |
| + self.num_discarded_requests = len(discard_request_indices) | |
| + self.discard_request_indices.np[: self.num_discarded_requests] = ( | |
| + discard_request_indices | |
| + ) | |
| + | |
| + self.discard_request_indices.copy_to_gpu(self.num_discarded_requests) | |
| # Copy the tensors to the GPU. | |
| - self._prepare_input_ids( | |
| - scheduler_output, | |
| - total_num_scheduled_tokens, | |
| - cu_num_tokens, | |
| - ) | |
| + if False: | |
| + self._prepare_input_ids( | |
| + scheduler_output, | |
| + total_num_scheduled_tokens, | |
| + cu_num_tokens, | |
| + ) | |
| + | |
| + if not follow_mega: | |
| + self.input_ids.copy_to_gpu(total_num_scheduled_tokens) | |
| if self.uses_mrope: | |
| # Only relevant for models using M-RoPE (e.g, Qwen2-VL) | |
| @@ -1353,7 +1398,10 @@ class GPUModelRunner( | |
| ) | |
| else: | |
| # Common case (1D positions) | |
| - self.positions.copy_to_gpu(total_num_scheduled_tokens) | |
| + if follow_mega: | |
| + self.positions.gpu[:total_num_scheduled_tokens] += 1 | |
| + else: | |
| + self.positions.copy_to_gpu(total_num_scheduled_tokens) | |
| use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 | |
| if not use_spec_decode: | |
| @@ -1362,10 +1410,14 @@ class GPUModelRunner( | |
| # from these partial requests, we do so for simplicity. | |
| # We will ignore the sampled tokens from the partial requests. | |
| # TODO: Support prompt logprobs. | |
| - logits_indices = query_start_loc[1:] - 1 | |
| + if not follow_mega: | |
| + self.saved_logits_indices = logits_indices = query_start_loc[1:] - 1 | |
| + self.saved_num_sampled_tokens = num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) | |
| + else: | |
| + logits_indices = self.saved_logits_indices | |
| + num_sampled_tokens = self.saved_num_sampled_tokens | |
| num_draft_tokens = None | |
| spec_decode_metadata = None | |
| - num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) | |
| else: | |
| # Get the number of draft tokens for each request. | |
| # Iterate over the dictionary rather than all requests since not all | |
| @@ -1408,6 +1460,9 @@ class GPUModelRunner( | |
| self.input_batch, num_scheduled_tokens, num_sampled_tokens | |
| ) | |
| + end_ts = time.time() | |
| + elapsed_ms = (end_ts - start_ts) * 1000 | |
| + # print(f"_prepare_inputs elapsed_ms {elapsed_ms:.3f} ms") | |
| return ( | |
| logits_indices, | |
| spec_decode_metadata, | |
| @@ -1426,6 +1481,7 @@ class GPUModelRunner( | |
| for_cudagraph_capture: bool = False, | |
| scheduled_encoder_inputs: dict[str, list[int]] | None = None, | |
| cascade_attn_prefix_lens: list[list[int]] | None = None, | |
| + follow_mega=False, | |
| ) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]: | |
| """ | |
| :return: tuple[attn_metadata, spec_decode_common_attn_metadata] | |
| @@ -1591,6 +1647,7 @@ class GPUModelRunner( | |
| attn_metadata_i = builder.build( | |
| common_prefix_len=cascade_attn_prefix_len, | |
| common_attn_metadata=common_attn_metadata, | |
| + follow_mega=follow_mega, | |
| **extra_attn_metadata_args, | |
| ) | |
| for layer_name in attn_group.layer_names: | |
| @@ -2618,6 +2675,9 @@ class GPUModelRunner( | |
| scheduler_output: "SchedulerOutput", | |
| intermediate_tensors: IntermediateTensors | None = None, | |
| ) -> ModelRunnerOutput | IntermediateTensors | None: | |
| + # print(f"execute_model scheduler_output = {scheduler_output}") | |
| + from vllm.v1.core.sched.scheduler import seen_mega_step | |
| + follow_mega = seen_mega_step and not scheduler_output.is_mega_step | |
| if self.execute_model_state is not None: | |
| raise RuntimeError( | |
| "State error: sample_tokens() must be called " | |
| @@ -2718,6 +2778,7 @@ class GPUModelRunner( | |
| use_spec_decode=use_spec_decode, | |
| scheduled_encoder_inputs=scheduler_output.scheduled_encoder_inputs, | |
| cascade_attn_prefix_lens=cascade_attn_prefix_lens, | |
| + follow_mega=follow_mega, | |
| ) | |
| ) | |
| @@ -2789,6 +2850,7 @@ class GPUModelRunner( | |
| inputs_embeds=inputs_embeds, | |
| **model_kwargs, | |
| ) | |
| + # if positions[0].item() == 16: breakpoint() | |
| with record_function_or_nullcontext("gpu_model_runner: postprocess"): | |
| if self.use_aux_hidden_state_outputs: | |
| @@ -2904,6 +2966,7 @@ class GPUModelRunner( | |
| with record_function_or_nullcontext("gpu_model_runner: sample"): | |
| sampler_output = self._sample(logits, spec_decode_metadata) | |
| + self.input_ids.gpu[:logits.shape[0]] = sampler_output.sampled_token_ids.flatten() | |
| self.input_batch.prev_sampled_token_ids = None | |
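On `follow_mega` steps the runner reuses the previous step's CPU-side state: the block table, slot mapping, `query_start_loc`, logits indices, and discard mask are not rebuilt. Instead, GPU tensors are advanced in place, and the previous step's sampled tokens (written into `input_ids.gpu` right after `_sample`) serve as the next inputs. A sketch of that in-place delta (tensor names mirror the patch; all are 1-D GPU tensors over the scheduled tokens/requests):

```python
import torch

def apply_follow_mega_delta(positions: torch.Tensor,
                            slot_mapping: torch.Tensor,
                            seq_lens: torch.Tensor,
                            input_ids: torch.Tensor,
                            sampled_token_ids: torch.Tensor) -> None:
    positions += 1      # every request advances by exactly one token
    slot_mapping += 1   # the next token lands in the next KV slot
    seq_lens += 1       # per-request sequence lengths grow by one
    # last step's samples become this step's inputs, no host round trip
    input_ids[: sampled_token_ids.numel()] = sampled_token_ids.flatten()
```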
| diff --git a/vllm/v1/worker/mega_step_runner.py b/vllm/v1/worker/mega_step_runner.py | |
| new file mode 100644 | |
| index 000000000..d1a602ee5 | |
| --- /dev/null | |
| +++ b/vllm/v1/worker/mega_step_runner.py | |
| @@ -0,0 +1,254 @@ | |
| +import numpy as np | |
| +import torch | |
| +from vllm.forward_context import BatchDescriptor, set_forward_context | |
| +from vllm.v1.outputs import ModelRunnerOutput | |
| +from dataclasses import dataclass | |
| + | |
| +# TODO delete | |
| +@dataclass | |
| +class MegaStepOutput: | |
| + sampled_token_ids_list: list[list[list[int]]] # one list for each mini step | |
| + | |
| +class MegaStepRunner: | |
| + def __init__(self, engine_core: "EngineCore", scheduler_output): | |
| + self.engine_core = engine_core | |
| + self.scheduler_output = scheduler_output | |
| + self.model_executor = self.engine_core.model_executor | |
| + self.scheduler = self.engine_core.scheduler | |
| + self.worker = self.model_executor.driver_worker.worker | |
| + self.model_runner = self.worker.model_runner | |
| + | |
| + # fields copied from model_runner | |
| + self.input_batch = self.model_runner.input_batch | |
| + self.input_ids = self.model_runner.input_ids | |
| + self.query_start_loc = self.model_runner.query_start_loc | |
| + self.seq_lens = self.model_runner.seq_lens | |
| + self.vllm_config = self.model_runner.vllm_config | |
| + self.cudagraph_dispatcher = self.model_runner.cudagraph_dispatcher | |
| + | |
| + self.num_input_tokens = self.scheduler_output.total_num_scheduled_tokens | |
| + self.num_reqs = self.input_batch.num_reqs | |
| + self.req_ids = self.input_batch.req_ids | |
| + | |
| + self.num_scheduled_tokens = np.array( | |
| + [scheduler_output.num_scheduled_tokens[i] for i in self.req_ids], | |
| + dtype=np.int32, | |
| + ) | |
| + | |
| + # TODO: find a better way to accumulate results | |
| + self.model_output_list = [] | |
| + | |
| + def _update_states(self): | |
| + req_data = self.scheduler_output.scheduled_cached_reqs | |
| + for i, req_id in enumerate(req_data.req_ids): | |
| + req_index = self.input_batch.req_id_to_index.get(req_id) | |
| + num_computed_tokens = req_data.num_computed_tokens[i] | |
| + self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens | |
| + | |
| + new_block_ids = req_data.new_block_ids[i] | |
| + | |
| + if new_block_ids is not None: | |
| + self.input_batch.block_table.append_row(new_block_ids, req_index) | |
| + | |
| + def _prepare_inputs(self, no_override_input): | |
| + self.input_batch.block_table.commit_block_table(self.num_reqs) | |
| + | |
| + req_indices = np.repeat(self.model_runner.arange_np[:self.num_reqs], self.num_scheduled_tokens) | |
| + cu_num_tokens, arange = self.model_runner._get_cumsum_and_arange(self.num_scheduled_tokens) | |
| + positions_np = self.model_runner.positions.np[:self.num_input_tokens] | |
| + np.add( | |
| + self.input_batch.num_computed_tokens_cpu[req_indices], | |
| + arange, | |
| + out=positions_np, | |
| + ) | |
| + self.model_runner.positions.copy_to_gpu(self.num_input_tokens) | |
| + | |
| + self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np) | |
| + self.input_batch.block_table.commit_slot_mapping(self.scheduler_output.total_num_scheduled_tokens) | |
| + | |
| + token_indices = ( | |
| + positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1] | |
| + ) | |
| + token_indices_tensor = torch.from_numpy(token_indices) | |
| + | |
| + torch.index_select( | |
| + self.input_batch.token_ids_cpu_tensor.flatten(), | |
| + 0, | |
| + token_indices_tensor, | |
| + out=self.input_ids.cpu[:self.num_input_tokens], | |
| + ) | |
| + | |
| + if not no_override_input: | |
| + self.input_ids.copy_to_gpu(self.num_input_tokens) | |
| + | |
| + self.query_start_loc.np[0] = 0 | |
| + self.query_start_loc.np[1 : self.num_reqs + 1] = cu_num_tokens | |
| + self.query_start_loc.np[self.num_reqs + 1 :].fill(cu_num_tokens[-1]) | |
| + self.query_start_loc.copy_to_gpu() | |
| + | |
| + query_start_loc = self.query_start_loc.gpu[: self.num_reqs + 1] | |
| + | |
| + self.seq_lens.np[:self.num_reqs] = self.input_batch.num_computed_tokens_cpu[:self.num_reqs] + self.num_scheduled_tokens | |
| + self.seq_lens.np[self.num_reqs:].fill(0) | |
| + self.seq_lens.copy_to_gpu() | |
| + | |
| + self.logits_indices = query_start_loc[1:] - 1 | |
| + | |
| + # fields setup by the initial step | |
| + self.attn_metadata = None | |
| + self.cudagraph_runtime_mode = None | |
| + self.batch_descriptor = None | |
| + self.positions = None | |
| + | |
| + def _dispatch_cudagraph(self): | |
| + batch_desc = BatchDescriptor( | |
| + num_tokens=self.num_input_tokens, | |
| + uniform_decode=True, | |
| + has_lora=False, | |
| + ) | |
| + cudagraph_runtime_mode, batch_descriptor = ( | |
| + self.cudagraph_dispatcher.dispatch( | |
| + batch_desc, | |
| + use_cascade_attn=False, | |
| + ) | |
| + ) | |
| + | |
| + return cudagraph_runtime_mode, batch_descriptor | |
| + | |
| + def _sample(self, hidden_states, logits): | |
| + sampler_output = self.model_runner._sample(logits, None) | |
| + self.input_ids.gpu[:logits.shape[0]] = sampler_output.sampled_token_ids.flatten() | |
| + ( | |
| + num_nans_in_logits, | |
| + logprobs_lists, | |
| + valid_sampled_token_ids, | |
| + prompt_logprobs_dict, | |
| + req_ids_output_copy, | |
| + req_id_to_index_output_copy, | |
| + invalid_req_indices, | |
| + ) = self.model_runner._bookkeeping_sync( | |
| + self.scheduler_output, | |
| + sampler_output, | |
| + logits, | |
| + hidden_states, | |
| + self.scheduler_output.total_num_scheduled_tokens, | |
| + None, | |
| + ) | |
| + | |
| + return ModelRunnerOutput( | |
| + req_ids=req_ids_output_copy, | |
| + req_id_to_index=req_id_to_index_output_copy, | |
| + sampled_token_ids=valid_sampled_token_ids, | |
| + logprobs=logprobs_lists, | |
| + prompt_logprobs_dict={}, | |
| + pooler_output=[], | |
| + kv_connector_output=None, | |
| + ec_connector_output=None, | |
| + num_nans_in_logits=num_nans_in_logits, | |
| + ) | |
| + | |
| + def initial_step(self, no_override_input=False): | |
| + self._update_states() | |
| + self._prepare_inputs(no_override_input) | |
| + input_ids = self.input_ids.gpu[:self.num_input_tokens] | |
| + self.positions = self.model_runner.positions.gpu[:self.num_input_tokens] | |
| + | |
| + self.attn_metadata, _ = self.model_runner._build_attention_metadata( | |
| + total_num_scheduled_tokens=self.num_input_tokens, | |
| + max_num_scheduled_tokens=1, | |
| + num_reqs=self.num_reqs, | |
| + ubatch_slices=None, | |
| + logits_indices=self.logits_indices, | |
| + use_spec_decode=False, | |
| + scheduled_encoder_inputs=None, | |
| + cascade_attn_prefix_lens=0, | |
| + follow_mega=True, | |
| + ) | |
| + self.cudagraph_runtime_mode, self.batch_descriptor = self._dispatch_cudagraph() | |
| + | |
| + with ( | |
| + set_forward_context( | |
| + self.attn_metadata, | |
| + self.vllm_config, | |
| + num_tokens=self.num_input_tokens, | |
| + num_tokens_across_dp=0, | |
| + cudagraph_runtime_mode=self.cudagraph_runtime_mode, | |
| + batch_descriptor=self.batch_descriptor, | |
| + ubatch_slices=None, | |
| + ) | |
| + ): | |
| + # print(f"input_ids={input_ids}") | |
| + # print(f"positions={positions}") | |
| + model_output = self.model_runner._model_forward( | |
| + input_ids=input_ids, | |
| + positions=self.positions, | |
| + ) | |
| + hidden_states = model_output | |
| + sample_hidden_states = hidden_states[self.logits_indices] | |
| + logits = self.model_runner.model.compute_logits(sample_hidden_states) | |
| + | |
| + model_output: ModelRunnerOutput = self._sample(hidden_states, logits) | |
| + | |
| + # breakpoint() | |
| + self.model_output_list.append(model_output) | |
| + | |
| + def apply_delta(self): | |
| + """Apply delta for the following steps""" | |
| + if False: | |
| + req_data = self.scheduler_output.scheduled_cached_reqs | |
| + for i, req_id in enumerate(req_data.req_ids): | |
| + req_data.num_computed_tokens[i] += 1 | |
| + | |
| + self.positions += 1 | |
| + # input ids is already setup in _sample | |
| + # update attn_metadata | |
| + # This assumes all value are the same object | |
| + meta_obj = next(iter(self.attn_metadata.values())) | |
| + meta_obj.slot_mapping += 1 | |
| + meta_obj.max_seq_len += 1 | |
| + meta_obj.seq_lens += 1 | |
| + | |
| + def following_step(self): | |
| + self.apply_delta() | |
| + with ( | |
| + set_forward_context( | |
| + self.attn_metadata, | |
| + self.vllm_config, | |
| + num_tokens=self.num_input_tokens, | |
| + num_tokens_across_dp=0, | |
| + cudagraph_runtime_mode=self.cudagraph_runtime_mode, | |
| + batch_descriptor=self.batch_descriptor, | |
| + ubatch_slices=None, | |
| + ) | |
| + ): | |
| + # print(f"input_ids={input_ids}") | |
| + # print(f"positions={positions}") | |
| + model_output = self.model_runner._model_forward( | |
| + input_ids=self.input_ids.gpu[:self.num_input_tokens], | |
| + positions=self.positions, | |
| + ) | |
| + hidden_states = model_output | |
| + sample_hidden_states = hidden_states[self.logits_indices] | |
| + logits = self.model_runner.model.compute_logits(sample_hidden_states) | |
| + | |
| + model_output: ModelRunnerOutput = self._sample(hidden_states, logits) | |
| + | |
| + self.model_output_list.append(model_output) | |
| + | |
| + | |
| + def construct_output(self): | |
| + assert len(self.model_output_list) > 0 | |
| + output = self.model_output_list[0] | |
| + sampled_token_ids = output.sampled_token_ids | |
| + for other in self.model_output_list[1:]: | |
| + other_token_ids = other.sampled_token_ids | |
| + for lhs, rhs in zip(sampled_token_ids, other_token_ids): | |
| + lhs.extend(rhs) | |
| + return output | |
| + | |
| + def run(self): | |
| + self.initial_step() | |
| + for offset in range(1, self.scheduler.scheduler_config.mega_step_size): | |
| + self.following_step() | |
| + | |
| + return self.construct_output() |
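`MegaStepRunner` runs one fully prepared `initial_step`, then `mega_step_size - 1` `following_step`s that only apply the +1 deltas to positions, slot mapping, and attention metadata before re-running the forward pass; `construct_output` merges the per-mini-step results into a single `ModelRunnerOutput` by concatenating each request's sampled token ids. A sketch of that merge:

```python
def merge_sampled_token_ids(per_step: list[list[list[int]]]) -> list[list[int]]:
    # Each mini step contributes one sampled token per request; concatenate
    # the per-request lists so one "mega" output carries all of them.
    merged = [list(req_tokens) for req_tokens in per_step[0]]
    for step in per_step[1:]:
        for req_tokens, new_tokens in zip(merged, step):
            req_tokens.extend(new_tokens)
    return merged

# Two requests, three mini steps -> three tokens per request.
assert merge_sampled_token_ids(
    [[[5], [9]], [[6], [10]], [[7], [11]]]
) == [[5, 6, 7], [9, 10, 11]]
```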