Skip to content

Instantly share code, notes, and snippets.

@shunting314
Created March 7, 2026 00:46
Show Gist options
  • Select an option

  • Save shunting314/3cd5ebb59ce0eb4cc4d1c131e4ba8349 to your computer and use it in GitHub Desktop.

Select an option

Save shunting314/3cd5ebb59ce0eb4cc4d1c131e4ba8349 to your computer and use it in GitHub Desktop.
import contextlib
import os

# Set the env override BEFORE importing vllm: vllm reads its VLLM_* environment
# configuration at import/engine-construction time, so setting it after the
# import risks the default value being used instead.
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"

import torch
from torch import nn
from torch import distributed

from vllm import LLM, SamplingParams
class script_args:
    """Static configuration knobs for this benchmark/profiling script.

    Attributes are read as ``script_args.<name>`` by the ``__main__`` body;
    names must stay as-is.
    """

    # HF model identifier passed straight to vllm's LLM(...).
    model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
    # Profiling is opt-in via the DO_PROFILE environment variable.
    profile = os.environ.get("DO_PROFILE") == "1"
    # When False, vllm compilation and cudagraphs are disabled below.
    compile = True
if __name__ == "__main__":
    # Either a real torch profiler or a no-op context manager, so the
    # `with profile:` block below works uniformly in both modes.
    if script_args.profile:
        profile = torch.profiler.profile(with_stack=True)
    else:
        profile = contextlib.nullcontext()

    # compilation_config=None lets vllm use its defaults; otherwise disable
    # both torch.compile-style compilation and cudagraph capture entirely.
    if script_args.compile:
        compilation_config = None
    else:
        from vllm.config import CompilationConfig, CUDAGraphMode, CompilationMode

        compilation_config = CompilationConfig(
            cudagraph_mode=CUDAGraphMode.NONE,
            mode=CompilationMode.NONE,
        )

    llm = LLM(
        model=script_args.model_name,
        compilation_config=compilation_config,
        attention_config={"backend": "FLEX_ATTENTION"},
    )

    # Greedy decoding; keep the trace small (32 tokens) when profiling,
    # otherwise generate a longer 512-token completion per request.
    sampling_params = SamplingParams(
        temperature=0,
        max_tokens=32 if script_args.profile else 128 * 4,
    )
    requests = [
        "How to estimate the value of pi in mathematics?",
        "Show me how quick-sort works.",
        "Can you explain FFT to me?",
    ]

    if script_args.profile:
        # Warmup pass so the profiled run excludes one-time startup costs
        # (kernel compilation, cache warmup, etc.).
        outputs = llm.generate(requests, sampling_params)

    with profile:
        outputs = llm.generate(requests, sampling_params)

    # One RequestOutput per input prompt is expected from llm.generate.
    assert len(outputs) == len(requests)
    for i, output in enumerate(outputs):
        print(f"Response for request {i}: {output.outputs[0].text}")

    if script_args.profile:
        path = "/tmp/profile.json"
        profile.export_chrome_trace(path)
        print(f"Profile written to {path}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment