Skip to content

Instantly share code, notes, and snippets.

@nascheme
Last active February 12, 2026 22:56
Show Gist options
  • Select an option

  • Save nascheme/e8edf295b7e57f971a58a6eb61736bae to your computer and use it in GitHub Desktop.

Select an option

Save nascheme/e8edf295b7e57f971a58a6eb61736bae to your computer and use it in GitHub Desktop.
Multi-threaded vllm benchmark
# Dual-engine multi-GPU threaded vLLM throughput benchmark.
#
# Architecture: two independent LLMEngine instances (one per GPU) fed from
# a shared tokenized-request queue, with a single tokenizer thread.
#
# Tokenizer Thread (CPU) Engine Thread 0 (cuda:0) Engine Thread 1 (cuda:1)
# input_processor.process() add_request (from queue) add_request (from queue)
# tokenized_queue.put(ecr) engine0.step() engine1.step()
# (continuous streaming) (continuous streaming)
#
# Each engine thread continuously pulls requests from the shared queue and
# steps the engine. The engine's internal scheduler handles batching.
# The shared queue naturally load-balances between engine threads because
# each step() releases the GIL for GPU work, giving the other thread
# time to pull from the queue.
#
# For TP=1/PP=1, parallel-state init is idempotent so both engines
# coexist in one process with zero vLLM source modifications.
#
# Uses vllm.benchmarks.datasets.RandomDataset for benchmark-comparable
# prompts, and reports tokens/sec metrics matching `vllm bench throughput`.
#
# Written with assistance from Claude Opus 4.6.
import os
# Must be set before any vllm import so both engines run inside this one
# process (the header comment above relies on single-process coexistence).
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
import queue
import sys
import threading
import time
from contextlib import contextmanager
import torch
from vllm import EngineArgs, SamplingParams
from vllm.benchmarks.datasets import RandomDataset
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.llm_engine import LLMEngine
from vllm.v1.executor.abstract import Executor
# ---------------------------------------------------------------------------
# Monkey-patch vllm.forward_context to use thread-local storage.
#
# The upstream module stores _forward_context as a plain module-level global,
# so two engine threads running model forward passes stomp on each other.
# We replace the three public accessors + the internal override context
# manager with thread-local equivalents, then fix up every module that
# already did `from vllm.forward_context import get_forward_context`.
# ---------------------------------------------------------------------------
import vllm.forward_context as _fc

# Per-thread slot for the active forward context (stored as attribute `ctx`).
_forward_ctx_tls = threading.local()
# Keep identity references to the originals so the sys.modules sweep below
# can recognize stale re-exports (`from vllm.forward_context import X`)
# by `is` comparison and rebind them.
_orig_get = _fc.get_forward_context
_orig_is_available = _fc.is_forward_context_available
_orig_override = _fc.override_forward_context
def _tl_get_forward_context():
    """Return the forward context installed for the calling thread.

    Thread-local replacement for vllm's module-global accessor; asserts
    that a context has been set before use.
    """
    current = getattr(_forward_ctx_tls, "ctx", None)
    assert current is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context."
    )
    return current
def _tl_is_forward_context_available():
    """Report whether the calling thread currently has a forward context."""
    try:
        return _forward_ctx_tls.ctx is not None
    except AttributeError:
        # This thread never set (or cleared) a context.
        return False
@contextmanager
def _tl_override_forward_context(forward_context):
    """Temporarily install *forward_context* for the calling thread.

    Whatever context was active on entry (possibly None) is restored on
    exit, even when the managed body raises.
    """
    saved = getattr(_forward_ctx_tls, "ctx", None)
    _forward_ctx_tls.ctx = forward_context
    try:
        yield
    finally:
        _forward_ctx_tls.ctx = saved
# (attribute name, original function, thread-local replacement) triples.
_patches = [
    ("get_forward_context", _orig_get, _tl_get_forward_context),
    (
        "is_forward_context_available",
        _orig_is_available,
        _tl_is_forward_context_available,
    ),
    (
        "override_forward_context",
        _orig_override,
        _tl_override_forward_context,
    ),
]
# Sweep every already-imported module and rebind any attribute that is
# still the *original* function object. Identity comparison (`is`) makes
# this safe: unrelated attributes that merely share a name are untouched.
# This also patches vllm.forward_context itself.
for _mod in list(sys.modules.values()):
    if _mod is None:
        continue
    for _attr, _orig, _new in _patches:
        try:
            if getattr(_mod, _attr, None) is _orig:
                setattr(_mod, _attr, _new)
        except (TypeError, AttributeError):
            # Some module objects reject getattr/setattr; skip those.
            pass
# ---------------------------------------------------------------------------
# Benchmark configuration.
MODEL = "HuggingFaceTB/SmolLM2-360M-Instruct"
NUM_GPUS = 2
NUM_REQUESTS = 1000
INPUT_LEN = RandomDataset.DEFAULT_INPUT_LEN  # 1024
OUTPUT_LEN = RandomDataset.DEFAULT_OUTPUT_LEN  # 128
# Max requests to pull from the shared queue per engine step. Keeps the
# pull small enough that both engine threads get a fair share, while still
# giving the scheduler enough requests to form efficient batches.
MAX_PULL_PER_STEP = 8
def create_engine(engine_args, device_index, usage_context):
    """Build an LLMEngine pinned to GPU ``device_index``.

    Mirrors ``LLMEngine.from_engine_args()``, but overwrites
    ``device_config.device`` after ``create_engine_config()`` because
    ``DeviceConfig.__post_init__`` normalizes ``torch.device("cuda:N")``
    to ``torch.device("cuda")``, dropping the index.
    """
    # Drop the module-level RotaryEmbedding cache so this engine allocates
    # fresh rope tensors on its own GPU instead of reusing another GPU's.
    from vllm.model_executor.layers.rotary_embedding import _ROPE_DICT

    _ROPE_DICT.clear()

    config = engine_args.create_engine_config(usage_context)
    config.device_config.device = torch.device(f"cuda:{device_index}")
    return LLMEngine(
        vllm_config=config,
        executor_class=Executor.get_class(config),
        log_stats=False,
        usage_context=usage_context,
        multiprocess_mode=False,
    )
def tokenizer_worker(
    input_processor, requests, sampling_params, tokenized_queue, done_event
):
    """Tokenizer thread: turn SampleRequests into EngineCoreRequests.

    Requests are processed in order; each is pushed onto
    ``tokenized_queue`` as an ``(engine_core_request, prompt_text)`` pair.
    ``done_event`` is set once the entire list has been queued.
    """
    for idx, sample in enumerate(requests):
        core_req = input_processor.process_inputs(
            str(idx),
            sample.prompt,
            sampling_params,
            arrival_time=time.time(),
        )
        tokenized_queue.put((core_req, sample.prompt))
    done_event.set()
def engine_worker(
    engine, device_index, tokenized_queue, sampling_params, tok_done, stats
):
    """Engine thread: continuously pull requests and step the engine.

    Instead of collecting a full batch and processing it to completion,
    this streams requests into the engine a few at a time. The engine's
    internal scheduler forms optimal GPU batches. Between step() calls
    the GIL is released for CUDA work, letting the other engine thread
    pull from the shared queue — naturally balancing load.

    Args:
        engine: LLMEngine pinned to ``cuda:device_index``.
        device_index: GPU ordinal this thread drives.
        tokenized_queue: shared queue of (EngineCoreRequest, prompt) pairs.
        sampling_params: SamplingParams applied to every request.
        tok_done: set by the tokenizer thread when all requests are queued.
        stats: three-slot list mutated in place:
            [num_completed, prompt_tokens, output_tokens].
    """

    def _submit(ecr, prompt_text):
        # Hand one tokenized request to the engine's internal scheduler.
        # (Was duplicated verbatim in two branches below; factored out.)
        engine.add_request(
            ecr.request_id,
            ecr,
            sampling_params,
            prompt_text=prompt_text,
        )

    torch.cuda.set_device(device_index)
    while True:
        # Pull a limited number of requests from the shared queue so the
        # other engine thread still gets a fair share.
        for _ in range(MAX_PULL_PER_STEP):
            try:
                _submit(*tokenized_queue.get_nowait())
            except queue.Empty:
                break
        if engine.has_unfinished_requests():
            for output in engine.step():
                if output.finished:
                    stats[0] += 1
                    if output.prompt_token_ids:
                        stats[1] += len(output.prompt_token_ids)
                    stats[2] += sum(
                        len(o.token_ids) for o in output.outputs if o
                    )
        elif tok_done.is_set() and tokenized_queue.empty():
            # Tokenizer is finished and nothing is left to pull: all done.
            break
        else:
            # No work yet — block briefly for the tokenizer to catch up.
            try:
                _submit(*tokenized_queue.get(timeout=0.5))
            except queue.Empty:
                if tok_done.is_set():
                    break
def main():
    """Run the dual-engine threaded vLLM throughput benchmark end to end."""
    # 1. Generate random dataset (CPU, before engine creation).
    print(
        f"Generating {NUM_REQUESTS} random requests "
        f"(input={INPUT_LEN}, output={OUTPUT_LEN}) ..."
    )
    tokenizer = get_tokenizer(MODEL)
    dataset = RandomDataset(dataset_path=None, random_seed=42)
    requests = dataset.sample(
        tokenizer=tokenizer,
        num_requests=NUM_REQUESTS,
        input_len=INPUT_LEN,
        output_len=OUTPUT_LEN,
    )
    print(f"Dataset ready: {len(requests)} requests")

    # 2. Create engines sequentially (parallel state init is idempotent
    # for TP=1, PP=1 — second engine reuses the already-initialized state).
    engine_args = EngineArgs(
        model=MODEL,
        enforce_eager=True,
        gpu_memory_utilization=0.8,
        async_scheduling=False,
    )
    engines = []
    for gpu in range(NUM_GPUS):
        print(f"Creating engine on cuda:{gpu} ...")
        engines.append(create_engine(engine_args, gpu, UsageContext.LLM_CLASS))
    print(f"All {NUM_GPUS} engines created.")

    # Benchmark-compatible sampling params: fixed output length, ignore EOS.
    sampling_params = SamplingParams(
        n=1,
        temperature=1.0,
        top_p=1.0,
        ignore_eos=True,
        max_tokens=OUTPUT_LEN,
    )

    # 3. Start tokenizer thread.
    tokenized_queue = queue.Queue(maxsize=NUM_REQUESTS)
    tok_done = threading.Event()
    tok_thread = threading.Thread(
        target=tokenizer_worker,
        args=(
            engines[0].input_processor,
            requests,
            sampling_params,
            tokenized_queue,
            tok_done,
        ),
        name="LLM::tok",
    )
    tok_thread.start()

    # 4. Start engine threads — one per GPU.
    # stats per engine: [num_completed, prompt_tokens, output_tokens]
    stats = [[0, 0, 0] for _ in range(NUM_GPUS)]
    engine_threads = []
    for gpu, engine in enumerate(engines):
        worker = threading.Thread(
            target=engine_worker,
            args=(
                engine,
                gpu,
                tokenized_queue,
                sampling_params,
                tok_done,
                stats[gpu],
            ),
            name=f"LLM::engine{gpu}",
        )
        worker.start()
        engine_threads.append(worker)

    # 5. Wait for completion with progress updates.
    start_time = time.time()
    while any(t.is_alive() for t in engine_threads):
        time.sleep(2)
        elapsed = time.time() - start_time
        completed_so_far = sum(s[0] for s in stats)
        print(f" [{elapsed:.1f}s] {completed_so_far}/{NUM_REQUESTS} completed ...")
    tok_thread.join()
    for t in engine_threads:
        t.join()

    # 6. Report results (matches `vllm bench throughput` format).
    elapsed = time.time() - start_time
    total_completed = sum(s[0] for s in stats)
    prompt_total = sum(s[1] for s in stats)
    output_total = sum(s[2] for s in stats)
    print(
        f"\nThroughput: {total_completed / elapsed:.2f} requests/s, "
        f"{(prompt_total + output_total) / elapsed:.2f} total tokens/s, "
        f"{output_total / elapsed:.2f} output tokens/s"
    )
    print(f"Total num prompt tokens: {prompt_total}")
    print(f"Total num output tokens: {output_total}")
    for gpu, s in enumerate(stats):
        print(
            f" cuda:{gpu}: {s[0]} reqs, "
            f"{s[1]} prompt toks, {s[2]} output toks"
        )


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment