Inference with streaming, dynamic batching, or static batching (draft)
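The LLM class below supports two inference paths. When a host is given, it talks to an OpenAI-compatible server (e.g. a stand-alone vLLM deployment) through AsyncOpenAI: with a batch size of 1 it streams tokens, and with larger batch sizes it sends the requests of each batch concurrently so the server can batch them dynamically. When no host is given, it loads the model into a local vLLM engine and runs each batch with static batching.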
import asyncio
import logging
import os
import warnings
from math import ceil
from typing import Any

from dotenv import load_dotenv
from huggingface_hub import login
from openai import AsyncOpenAI
from tqdm import tqdm

load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
CACHE_DIR = os.getenv("CACHE_DIR")

if not HF_TOKEN:
    warnings.warn(
        "HF_TOKEN not found in environment variables. "
        "Gated models and datasets will not be accessible."
    )


class LLM:
    hub_path: str
    temperature: float
    max_length_output: int
    host: str | None = None
    port: str
    model: Any

    def __init__(
        self,
        hub_path: str,
        temperature: float,
        max_length_output: int,
        host: str | None,
        port: str,
        num_gpus: int | None = None,
    ):
        self.hub_path = hub_path
        self.temperature = temperature
        self.max_length_output = max_length_output
        if HF_TOKEN:
            login(token=HF_TOKEN)
        if host:
            self.host = host
            self.port = port
            logging.debug(f"LLM: using model being served at {host}:{port}")
            # Use the OpenAI-compatible server running as a separate process
            self.model = AsyncOpenAI(
                api_key="N/A",
                base_url=f"http://{self.host}:{self.port}/v1",
            )
            return
        logging.debug(f"LLM: loading model: {hub_path}")
        from vllm import LLM as vLLM

        # Manage our own vLLM instance if not using a stand-alone inference engine
        self.model = vLLM(
            model=hub_path,
            tensor_parallel_size=num_gpus if num_gpus else 1,
            download_dir=CACHE_DIR,
        )

    def _infer_locally(self, batch: list[str]) -> list[str]:
        from vllm import SamplingParams

        sampling_params = SamplingParams(
            temperature=self.temperature,
            max_tokens=self.max_length_output,
        )
        generations = self.model.generate(batch, sampling_params)
        return [g.outputs[0].text for g in generations]

    async def _post_streaming_request(self, prompt: str) -> str:
        stream = await self.model.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=self.hub_path,
            temperature=self.temperature,
            max_tokens=self.max_length_output,
            stream=True,
        )
        idx = 0
        output = ""
        async for chunk in stream:
            if (content := chunk.choices[0].delta.content) is not None:
                if idx == 0:
                    # The time to first token (latency) can be measured here
                    pass
                output += content
                idx += 1
        return output

    async def _post_single_request(self, prompt: str) -> str:
        response = await self.model.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=self.hub_path,
            temperature=self.temperature,
            max_tokens=self.max_length_output,
        )
        return response.choices[0].message.content

    async def _post_batched_requests(self, batch: list[str]) -> list[str]:
        requests = map(
            lambda prompt: self._post_single_request(prompt),
            batch,
        )
        outputs = await asyncio.gather(*requests)
        return outputs

    def _batches(self, prompts: list[str], batch_size: int) -> list[list[str]]:
        batches = []
        for i in range(0, len(prompts), batch_size):
            batches.append(prompts[i : i + batch_size])
        return batches

    def _split_outputs(self, batch: list[str], outputs: list[str]) -> list[str]:
        split_outputs = []
        for i, prompt in enumerate(batch):
            splits = outputs[i].split(prompt)
            split_outputs.append(splits[-1])
        return split_outputs

    async def perform_inference(self, batch_size: int) -> list[str]:
        logging.debug(f"LLM: running inference for batch size {batch_size}")
        prompts = ["PROMPT1", "PROMPT2, ..."]  # Replace with actual prompts
        num_batches = ceil(len(prompts) / batch_size)
        logging.debug(
            f"LLM: processing {len(prompts)} prompts in {num_batches} batches (batch size = {batch_size})"
        )
        split_outputs = []
        for batch in tqdm(self._batches(prompts, batch_size)):
            if self.host:
                if batch_size == 1:
                    # Use streaming
                    outputs = [await self._post_streaming_request(batch[0])]
                else:
                    # Use the async client, sending the requests of a batch concurrently
                    outputs = await self._post_batched_requests(batch)
            else:
                # Use static batching
                outputs = self._infer_locally(batch)
            split_outputs += self._split_outputs(batch, outputs)
        return split_outputs
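
A minimal usage sketch, not part of the gist itself: the model name, host, port, temperature, and batch size below are illustrative assumptions. Point host and port at a running OpenAI-compatible server, or pass host=None to load the model locally with vLLM instead.

import asyncio


async def main() -> None:
    # Assumed values; replace the model, host, and port with your own setup
    llm = LLM(
        hub_path="meta-llama/Llama-3.1-8B-Instruct",
        temperature=0.7,
        max_length_output=256,
        host="localhost",
        port="8000",
        # num_gpus only matters when host is None and a local vLLM engine is created
    )
    # batch_size == 1 streams token by token; larger values send concurrent requests
    outputs = await llm.perform_inference(batch_size=4)
    print(outputs)


if __name__ == "__main__":
    asyncio.run(main())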