Inference with streaming, dynamic batching, or static batching (draft)
import asyncio
import logging
import os
import warnings
from math import ceil
from typing import Any

from dotenv import load_dotenv
from huggingface_hub import login
from openai import AsyncOpenAI
from tqdm import tqdm

load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
CACHE_DIR = os.getenv("CACHE_DIR")

if not HF_TOKEN:
    warnings.warn(
        "HF_TOKEN not found in environment variables. "
        "Gated models and datasets will not be accessible."
    )


class LLM:
    hub_path: str
    temperature: float
    max_length_output: int
    host: str | None = None
    port: str
    model: Any

    def __init__(
        self,
        hub_path: str,
        temperature: float,
        max_length_output: int,
        host: str | None,
        port: str,
        num_gpus: int | None = None,
    ):
        self.hub_path = hub_path
        self.temperature = temperature
        self.max_length_output = max_length_output

        if HF_TOKEN:
            login(token=HF_TOKEN)

        if host:
            self.host = host
            self.port = port
            logging.debug(f"LLM: using model being served at {host}:{port}")
            # Use the OpenAI-compatible model running as a separate process
            self.model = AsyncOpenAI(
                api_key="N/A",
                base_url=f"http://{self.host}:{self.port}/v1",
            )
            return

        logging.debug(f"LLM: loading model: {hub_path}")
        from vllm import LLM as vLLM

        # Manage own instance of vLLM if not using a stand-alone inference engine
        self.model = vLLM(
            model=hub_path,
            tensor_parallel_size=num_gpus if num_gpus else 1,
            download_dir=CACHE_DIR,
        )

    def _infer_locally(self, batch: list[str]) -> list[str]:
        # Static batching: run the whole batch through the local vLLM engine
        from vllm import SamplingParams

        sampling_params = SamplingParams(
            temperature=self.temperature,
            max_tokens=self.max_length_output,
        )
        generations = self.model.generate(batch, sampling_params)
        return [g.outputs[0].text for g in generations]

    async def _post_streaming_request(self, prompt: str) -> str:
        # Stream the completion for a single prompt, token chunk by token chunk
        stream = await self.model.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=self.hub_path,
            temperature=self.temperature,
            max_tokens=self.max_length_output,
            stream=True,
        )
        idx = 0
        output = ""
        async for chunk in stream:
            if (content := chunk.choices[0].delta.content) is not None:
                if idx == 0:
                    # You can measure the time to first token (latency) here
                    pass
                output += content
                idx += 1
        return output

    async def _post_single_request(self, prompt: str) -> str:
        response = await self.model.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=self.hub_path,
            temperature=self.temperature,
            max_tokens=self.max_length_output,
        )
        return response.choices[0].message.content

    async def _post_batched_requests(self, batch: list[str]) -> list[str]:
        # Send all requests in the batch concurrently; the server handles
        # dynamic batching on its side
        requests = map(
            lambda prompt: self._post_single_request(prompt),
            batch,
        )
        outputs = await asyncio.gather(*requests)
        return outputs

    def _batches(self, prompts: list[str], batch_size: int) -> list[list[str]]:
        batches = []
        for i in range(0, len(prompts), batch_size):
            batches.append(prompts[i : i + batch_size])
        return batches

    def _split_outputs(self, batch: list[str], outputs: list[str]) -> list[str]:
        # Strip the prompt from the output if the backend echoes it back
        split_outputs = []
        for i, prompt in enumerate(batch):
            splits = outputs[i].split(prompt)
            split_outputs.append(splits[-1])
        return split_outputs

    async def perform_inference(self, batch_size: int) -> list[str]:
        logging.debug(f"LLM: running inference for batch size {batch_size}")
        prompts = ["PROMPT1", "PROMPT2, ..."]  # Replace with actual prompts
        num_batches = ceil(len(prompts) / batch_size)
        logging.debug(
            f"LLM: processing {len(prompts)} prompts in {num_batches} batches (batch size = {batch_size})"
        )
        split_outputs = []
        for batch in tqdm(self._batches(prompts, batch_size)):
            if self.host:
                if batch_size == 1:
                    # Use streaming
                    outputs = [await self._post_streaming_request(batch[0])]
                else:
                    # Use async client, sending requests in a batch concurrently
                    outputs = await self._post_batched_requests(batch)
            else:
                # Use static batching
                outputs = self._infer_locally(batch)
            split_outputs += self._split_outputs(batch, outputs)
        return split_outputs
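
A minimal usage sketch (not part of the original gist): it assumes the server-backed path, i.e. an OpenAI-compatible endpoint already running (for example, started separately with `vllm serve <model> --port 8000`), and the model id, host, port, and batch size below are placeholders to adapt.

# Hypothetical usage sketch: model id, host, port, and batch size are placeholders.
# The server-backed path assumes an OpenAI-compatible endpoint is already running,
# e.g. launched separately with `vllm serve <model> --port 8000`.
async def main() -> None:
    llm = LLM(
        hub_path="MODEL_ID",   # e.g. a Hugging Face hub path
        temperature=0.7,
        max_length_output=256,
        host="localhost",      # pass None to load the model locally with vLLM instead
        port="8000",
    )
    # Note: prompts are currently hard-coded placeholders inside perform_inference
    outputs = await llm.perform_inference(batch_size=8)
    print(outputs)


if __name__ == "__main__":
    asyncio.run(main())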