Gist by @qgallouedec, created October 17, 2025 17:20
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import time

import datasets
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
DISPLAYED_SAMPLES = 3  # number of prompt/completion pairs printed after generation

if __name__ == "__main__":
    # Parse args
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-blocks", "-n", type=int, default=None)
    parser.add_argument("--max-batch-tokens", "-b", type=int, default=None)
    parser.add_argument("--attn", type=str, default="kernels-community/flash-attn", help="Attention implementation")
    parser.add_argument("--samples", type=int, default=500)
    parser.add_argument("--max-new-tokens", type=int, default=32)
    args = parser.parse_args()

    # Prepare model
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        attn_implementation=args.attn,
        device_map="cuda",
        dtype=torch.bfloat16,
    )
    model = model.eval()
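    # Note: the default --attn value "kernels-community/flash-attn" points to a prebuilt FlashAttention
    # kernel on the Hugging Face Hub, which assumes the `kernels` package is installed; built-in strings
    # such as "sdpa" should also work here (assumption, not exercised in this gist).
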
    # Prepare tokenizer and dataset
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")
    dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
    dataset = dataset.select(range(args.samples))
    tokenized_datasets = dataset.map(lambda x: tokenizer(x["question"]), batched=True)
    simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets]
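    # Each entry of simple_batch_inputs is a variable-length list of token ids (one GSM8K question);
    # batching into fixed-size tensors is deferred to the pad_sequence calls below.
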
    # Prepare generation config
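    # num_blocks and max_batch_tokens are extra knobs for the paged KV cache used by the
    # continuous-batching generation path in transformers; when left at None the library defaults
    # apply (my reading of these options, not stated in the gist itself).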
    generation_config = GenerationConfig(
        max_new_tokens=args.max_new_tokens,
        use_cuda_graph=False,  # Not supported for simple version
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        do_sample=False,
        num_blocks=args.num_blocks,
        max_batch_tokens=args.max_batch_tokens,
    )

    # Warmup iteration: generate on a handful of prompts first so kernel compilation and cache
    # allocation do not skew the timed run below
    input_ids = [torch.tensor(xx, device="cuda") for xx in simple_batch_inputs[: min(5, args.samples)]]
    attention_mask = [torch.ones_like(xx) for xx in input_ids]
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id, padding_side="left")
    attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0, padding_side="left")
    _ = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        generation_config=generation_config,
    )

    input_ids = [torch.tensor(xx, device="cuda") for xx in simple_batch_inputs]
    attention_mask = [torch.ones_like(xx) for xx in input_ids]
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id, padding_side="left")
    attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0, padding_side="left")
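    # Left padding right-aligns every prompt in the batch so newly generated tokens immediately
    # follow the prompt for each row, which decoder-only models need for correct batched generation.
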
    # Actual batch generation
    print("--- Running Generation Example ---")
    torch.cuda.synchronize()
    start_time = time.time()
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        generation_config=generation_config,
    )
    torch.cuda.synchronize()
    end_time = time.time()
print("Done with generation.")
completion_ids = outputs[:, input_ids.shape[1]:]
prompts = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
completions = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
# Decode outputs
token_count = args.max_new_tokens * args.samples
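    # Note: token_count assumes every sequence produces max_new_tokens; sequences that hit EOS
    # earlier generate fewer real tokens, so the reported tok/s is an upper-bound estimate.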
    for i, (prompt, completion) in enumerate(zip(prompts, completions)):
        # Display sample if asked
        if i < DISPLAYED_SAMPLES:
            print("-" * 20)
            print(f"Input: {prompt}")
            print(f"Output: {completion}")

    # Compute stats and print them
    gen_time = end_time - start_time
    tok_per_sec = token_count / gen_time
    print("-" * 20)
    print("--- Finished Generation Example ---\n")
    print(f"Generation took: {gen_time:.2f} seconds for {token_count} tokens. {tok_per_sec:.2f} tok/s")