Skip to content

Instantly share code, notes, and snippets.

@ParagEkbote
Last active February 23, 2026 09:35
Show Gist options
  • Select an option

  • Save ParagEkbote/208d42e5992928795c690b09bfc9bd62 to your computer and use it in GitHub Desktop.

Select an option

Save ParagEkbote/208d42e5992928795c690b09bfc9bd62 to your computer and use it in GitHub Desktop.
import os
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
TorchAoConfig,
)
from torchao.quantization import Int8WeightOnlyConfig
# Allow TF32 matmuls/convolutions — faster FP32 math on Ampere+ GPUs at
# slightly reduced precision; acceptable for a throughput benchmark.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
MODEL_ID = "HuggingFaceTB/SmolLM3-3B"  # model benchmarked across all variants
RUNS = 5  # timed iterations per benchmark
WARMUP = 1  # untimed warm-up iterations (kernel autotune / torch.compile)
GEN_TOKENS = 275  # new tokens generated per decode run
BATCH_SIZE = 3  # identical prompts per batch
PROMPT = "Paris, the capital of France is " * 50  # repeated to stress prefill
FA2 = "kernels-community/flash-attn2"  # hub kernel id for attn_implementation
os.makedirs("benchmark_plots", exist_ok=True)  # output directory for figures
# ---------------------------------------------------------------------------
# Tokenizer
# ---------------------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)


def prepare_inputs(device):
    """Tokenize BATCH_SIZE copies of PROMPT and move the batch to *device*."""
    texts = [PROMPT] * BATCH_SIZE
    encoded = tokenizer(texts, return_tensors="pt", padding=True)
    return encoded.to(device)
# ---------------------------------------------------------------------------
# Benchmark helpers
# ---------------------------------------------------------------------------
def stats(arr: np.ndarray) -> dict:
    """Summarise a sample of measurements: mean, std, 95th percentile and CV.

    CV (coefficient of variation = std / mean) is reported as 0.0 when the
    mean is not positive, so a degenerate sample never divides by zero.
    """
    mu = float(np.mean(arr))
    sigma = float(np.std(arr))
    variation = sigma / mu if mu > 0 else 0.0
    return {
        "mean": mu,
        "std": sigma,
        "p95": float(np.percentile(arr, 95)),
        "cv": float(variation),
    }
def benchmark_prefill(model, name: str) -> np.ndarray:
    """Time RUNS forward passes without KV cache; return per-run latencies (s).

    Warm-up passes run first and are excluded from the timings; every timed
    pass is bracketed with cuda.synchronize so wall-clock time covers the
    full GPU work.
    """
    device = next(model.parameters()).device
    batch = prepare_inputs(device)
    timings = []
    model.eval()
    with torch.no_grad():
        for _ in range(WARMUP):
            model(**batch, use_cache=False)
    torch.cuda.synchronize()
    for _ in tqdm(range(RUNS), desc=f"Prefill {name}"):
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
        start = time.perf_counter()
        with torch.no_grad():
            model(**batch, use_cache=False)
        torch.cuda.synchronize()
        timings.append(time.perf_counter() - start)
    return np.array(timings)
def benchmark_decode(model, name: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Benchmark greedy generation.

    Returns three per-run arrays: latency in seconds, throughput in
    tokens/second (new tokens × batch size / elapsed), and peak GPU memory
    in MB.
    """
    device = next(model.parameters()).device
    batch = prepare_inputs(device)
    gen_kwargs = dict(
        max_new_tokens=GEN_TOKENS,
        do_sample=False,
        use_cache=True,
        cache_implementation="static",
    )
    lat, tps, mem = [], [], []
    model.eval()
    # Warmup with the same shape used during benchmarking (critical for compile).
    with torch.no_grad():
        for _ in range(WARMUP):
            model.generate(**batch, **gen_kwargs)
    torch.cuda.synchronize()
    for _ in tqdm(range(RUNS), desc=f"Decode {name}"):
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
        start = time.perf_counter()
        with torch.no_grad():
            out = model.generate(**batch, **gen_kwargs)
        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
        new_tokens = out.shape[1] - batch["input_ids"].shape[1]
        lat.append(elapsed)
        tps.append((new_tokens * BATCH_SIZE) / elapsed)
        mem.append(torch.cuda.max_memory_allocated() / 1024**2)
    return np.array(lat), np.array(tps), np.array(mem)
# ---------------------------------------------------------------------------
# Model loaders
# ---------------------------------------------------------------------------
def load_eager_bf16():
    """Baseline variant: bfloat16 weights, plain eager attention, no compile."""
    load_kwargs = {
        "dtype": torch.bfloat16,
        "device_map": "cuda",
        "attn_implementation": "eager",
    }
    return AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)
def load_fa2_bf16():
    """FA2 hub kernel only — isolates the attention kernel's contribution."""
    load_kwargs = {
        "dtype": torch.bfloat16,
        "device_map": "cuda",
        "attn_implementation": FA2,
    }
    return AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)
def load_fa2_compile():
    """Stack torch.compile on the FA2 model to measure compile gains on top."""
    base = load_fa2_bf16()
    return torch.compile(base, mode="reduce-overhead", fullgraph=False)
def load_torchao():
    """FA2 attention plus TorchAO int8 weight-only quantisation, no compile."""
    qcfg = TorchAoConfig(quant_type=Int8WeightOnlyConfig(group_size=None))
    return AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="cuda",
        quantization_config=qcfg,
        attn_implementation=FA2,
    )
def load_bnb():
    """FA2 attention plus bitsandbytes 8-bit quantisation."""
    bnb_cfg = BitsAndBytesConfig(load_in_8bit=True)
    return AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="cuda",
        quantization_config=bnb_cfg,
        attn_implementation=FA2,
    )
def load_torchao_compile():
    """Full optimisation stack: FA2 + TorchAO int8 + torch.compile."""
    quantised = load_torchao()
    return torch.compile(quantised, mode="reduce-overhead", fullgraph=False)
# ---------------------------------------------------------------------------
# Run benchmarks
# ---------------------------------------------------------------------------
# Variant name -> loader; benchmarked sequentially so only one model is
# resident on the GPU at a time.
VARIANTS = {
    "Eager BF16": load_eager_bf16,
    "FA2": load_fa2_bf16,
    "FA2+Compile": load_fa2_compile,
    "FA2+TorchAO": load_torchao,
    "FA2+BnB": load_bnb,
    "FA2+TAO+Compile": load_torchao_compile,
}
results = {}
for name, loader in VARIANTS.items():
    rule = "=" * 60
    print(f"\n{rule}")
    print(f"Loading Variant: {name}")
    print(f"{rule}")
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    model = loader()
    # Footprint diagnostic before any benchmarking work.
    gib = 1024**3
    print(f"Initial allocated: {torch.cuda.memory_allocated()/gib:.2f} GB")
    print(f"Initial reserved : {torch.cuda.memory_reserved()/gib:.2f} GB")
    prefill_lat = benchmark_prefill(model, name)
    decode_lat, decode_tps, decode_mem = benchmark_decode(model, name)
    results[name] = {
        "prefill": stats(prefill_lat),
        "decode": stats(decode_lat),
        "throughput": stats(decode_tps),
        "memory": stats(decode_mem),
    }
    # Release the model fully before the next load, otherwise two 3B models
    # can coexist on the GPU and OOM.
    del model
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    print(f"Freed memory. Current allocated: {torch.cuda.memory_allocated()/gib:.2f} GB")
# ---------------------------------------------------------------------------
# Plot configuration
# ---------------------------------------------------------------------------
# Shared plotting state: one colour per variant, consistent across all figures.
PALETTE = ["#A8DADC", "#457B9D", "#1D3557", "#E63946", "#F4A261", "#06D6A0"]
labels = list(results.keys())  # variant names, in benchmark order
x = np.arange(len(labels))  # bar / x-tick positions
colors = PALETTE[: len(labels)]
# Break at "+" and abbreviate "Compile" so tick labels fit under the bars.
short_labels = [l.replace("+", "+\n").replace("Compile", "Cmp") for l in labels]
def _bar(ax, means, stds, ylabel: str, title: str) -> None:
    """Draw one annotated bar chart onto *ax* using the shared ticks/colours."""
    rects = ax.bar(
        x, means,
        yerr=stds,
        capsize=5,
        color=colors,
        edgecolor="white",
        linewidth=0.6,
        error_kw={"elinewidth": 1.5, "ecolor": "#555"},
    )
    ax.set_xticks(x)
    ax.set_xticklabels(short_labels, fontsize=8)
    ax.set_ylabel(ylabel, fontsize=9)
    ax.set_title(title, fontsize=10, fontweight="bold", pad=8)
    ax.spines[["top", "right"]].set_visible(False)
    ax.yaxis.grid(True, linestyle="--", alpha=0.4)
    ax.set_axisbelow(True)
    # Label each bar with its value just above the top edge.
    for rect, value in zip(rects, means):
        midpoint = rect.get_x() + rect.get_width() / 2
        ax.text(
            midpoint,
            rect.get_height() * 1.01,
            f"{value:.1f}",
            ha="center", va="bottom", fontsize=7.5, fontweight="bold",
        )
# ---------------------------------------------------------------------------
# Figure 1 — 2×2 core metrics grid
# ---------------------------------------------------------------------------
fig, axes = plt.subplots(2, 2, figsize=(14, 9))
fig.suptitle(
    f"LLM Inference Benchmark — {MODEL_ID.split('/')[-1]}\n"
    f"batch={BATCH_SIZE} gen_tokens={GEN_TOKENS} runs={RUNS}",
    fontsize=12, fontweight="bold",
)


def _series(section: str, field: str) -> list:
    """Collect results[variant][section][field] for every variant, in order."""
    return [results[k][section][field] for k in labels]


_bar(axes[0, 0], _series("prefill", "mean"), _series("prefill", "std"),
     "Latency (s)", "Prefill Latency (lower is better)")
_bar(axes[0, 1], _series("throughput", "mean"), _series("throughput", "std"),
     "Tokens / sec", "Decode Throughput (higher is better)")
_bar(axes[1, 0], _series("memory", "mean"), _series("memory", "std"),
     "Peak VRAM (MB)", "Peak GPU Memory (lower is better)")
# Coefficient of variation — lower means more predictable latency.
cv_vals = [results[k]["decode"]["cv"] * 100 for k in labels]
ax = axes[1, 1]
bars = ax.bar(x, cv_vals, color=colors, edgecolor="white", linewidth=0.6)
ax.set_xticks(x)
ax.set_xticklabels(short_labels, fontsize=8)
ax.set_ylabel("CV (std / mean × 100)", fontsize=9)
ax.set_title("Decode Stability — CV (lower is better)", fontsize=10, fontweight="bold", pad=8)
ax.spines[["top", "right"]].set_visible(False)
ax.yaxis.grid(True, linestyle="--", alpha=0.4)
ax.set_axisbelow(True)
for rect, value in zip(bars, cv_vals):
    ax.text(
        rect.get_x() + rect.get_width() / 2,
        rect.get_height() * 1.01,
        f"{value:.1f}%",
        ha="center", va="bottom", fontsize=7.5, fontweight="bold",
    )
plt.tight_layout()
plt.savefig("benchmark_plots/core_metrics.png", dpi=300, bbox_inches="tight")
plt.close()
print("Saved: benchmark_plots/core_metrics.png")
# ---------------------------------------------------------------------------
# Figure 2 — efficiency scatter
# Axes: throughput (y) vs memory (x); bubble size ∝ prefill latency.
# Ideal variants sit top-left with small bubbles.
# ---------------------------------------------------------------------------
fig, ax = plt.subplots(figsize=(9, 6))
# Per-variant means, in label order.
throughput_y = [results[k]["throughput"]["mean"] for k in labels]
memory_x = [results[k]["memory"]["mean"] for k in labels]
prefill = np.array([results[k]["prefill"]["mean"] for k in labels])
# Map prefill latency into bubble areas in [150, 1550]; the epsilon keeps the
# division finite when every variant has the same prefill mean.
span = (prefill.max() - prefill.min()) + 1e-9
sizes = 150 + 1400 * (prefill - prefill.min()) / span
ax.scatter(
    memory_x, throughput_y,
    s=sizes, c=colors,
    edgecolors="white", linewidths=1.2, alpha=0.85, zorder=3,
)
for idx, variant in enumerate(labels):
    ax.annotate(
        variant, (memory_x[idx], throughput_y[idx]),
        textcoords="offset points", xytext=(9, 4),
        fontsize=8, fontweight="bold",
    )
ax.set_xlabel("Peak GPU Memory (MB) ← lower is better", fontsize=10)
ax.set_ylabel("Decode Throughput (tok/s) ↑ higher is better", fontsize=10)
ax.set_title(
    "Efficiency Frontier: Throughput vs Memory\n"
    "(bubble size ∝ prefill latency — smaller = faster to first token)",
    fontsize=11, fontweight="bold",
)
ax.spines[["top", "right"]].set_visible(False)
ax.yaxis.grid(True, linestyle="--", alpha=0.35)
ax.xaxis.grid(True, linestyle="--", alpha=0.35)
ax.set_axisbelow(True)
plt.tight_layout()
plt.savefig("benchmark_plots/efficiency_scatter.png", dpi=300, bbox_inches="tight")
plt.close()
print("Saved: benchmark_plots/efficiency_scatter.png")
# ---------------------------------------------------------------------------
# Figure 3 — normalised radar
# All four metrics normalised 0–1; outer edge = best in that category.
# ---------------------------------------------------------------------------
def normalise(vals, higher_better: bool = True) -> np.ndarray:
    """Min-max scale *vals* into [0, 1]; invert when lower raw values are better."""
    arr = np.array(vals, dtype=float)
    # Epsilon in the denominator avoids 0/0 when all values coincide.
    span = (arr.max() - arr.min()) + 1e-9
    scaled = (arr - arr.min()) / span
    return scaled if higher_better else 1 - scaled
categories = ["Throughput\n(↑)", "Prefill\nSpeed (↑)", "Memory\nEffic. (↑)", "Stability\n(↑)"]
num_axes = len(categories)
spokes = np.linspace(0, 2 * np.pi, num_axes, endpoint=False).tolist()
spokes.append(spokes[0])  # repeat the first angle so each polygon closes
# One normalised score per category, each in label order; lower-is-better
# metrics are inverted so the outer edge is always "best".
score_columns = [
    normalise([results[k]["throughput"]["mean"] for k in labels], higher_better=True),
    normalise([results[k]["prefill"]["mean"] for k in labels], higher_better=False),
    normalise([results[k]["memory"]["mean"] for k in labels], higher_better=False),
    normalise([results[k]["decode"]["cv"] for k in labels], higher_better=False),
]
fig, ax = plt.subplots(figsize=(7, 7), subplot_kw={"polar": True})
ax.set_theta_offset(np.pi / 2)
ax.set_theta_direction(-1)
ax.set_thetagrids(np.degrees(spokes[:-1]), categories, fontsize=9)
ax.set_ylim(0, 1)
ax.set_yticks([0.25, 0.5, 0.75, 1.0])
ax.set_yticklabels(["0.25", "0.50", "0.75", "1.0"], fontsize=7, alpha=0.5)
ax.yaxis.grid(True, linestyle="--", alpha=0.3)
ax.spines["polar"].set_visible(False)
for idx, variant in enumerate(labels):
    polygon = [column[idx] for column in score_columns]
    polygon.append(polygon[0])
    ax.plot(spokes, polygon, color=colors[idx], linewidth=2, label=variant)
    ax.fill(spokes, polygon, color=colors[idx], alpha=0.08)
ax.set_title(
    "Normalised Performance Radar\n(outer edge = best in category)",
    fontsize=11, fontweight="bold", pad=20,
)
ax.legend(loc="upper right", bbox_to_anchor=(1.4, 1.15), fontsize=8)
plt.tight_layout()
plt.savefig("benchmark_plots/radar.png", dpi=300, bbox_inches="tight")
plt.close()
print("Saved: benchmark_plots/radar.png")
print("\nAll plots saved to benchmark_plots/")
transformers
torch
numpy
torchao
matplotlib
accelerate
kernels
bitsandbytes
huggingface_hub
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment