#! /usr/bin/env bash
# === benchmark_iree_on_arm.sh ================================================
#
# Benchmarks models from https://huggingface.co/mlir-for-sve/sve-models/.
#
# NOTE - This script will not clone the models repo!
# NOTE - Set IREE_BUILD_DIR and SVE_MODELS_DIR accordingly!
#
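# Example setup (paths are illustrative; large model files may need git-lfs):
#
#   git clone https://huggingface.co/mlir-for-sve/sve-models "$HOME/sve-models"
#
# Then point SVE_MODELS_DIR below at the checkout, and IREE_BUILD_DIR at an
# existing IREE build containing tools/iree-compile and
# tools/iree-benchmark-module.
#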
# =============================================================================
set -euo pipefail
# SET THESE UP!
readonly IREE_BUILD_DIR=$HOME/iree-build/Release/
readonly SVE_MODELS_DIR=$HOME/sve-models
FORMAT_STR_GREY='\e[1m%-6s\e[0m'
FORMAT_STR_GREEN='\e[1;32m%-6s\e[0m'
FORMAT_STR='\e[1;37m%-6s\e[0m'
# === build_and_run ============================================================
#
# Compiles and benchmarks a model
#
# $1 - Model to benchmark (no suffix)
# $2 - Whether to enable distribution
# $3 - Whether to enable data-tiling
# $4 - Input arguments for iree-benchmark-module
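#
# Example (mirrors the calls in main below):
#   build_and_run vit-base-patch16-224 false true '--input="1x3x224x224xf32=1.0"'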
# =============================================================================
build_and_run()
{
  local -r model="$1"
  local -r distribution="$2"
  local -r dt_on="$3"
  local -r pipeline_input="$4"

  printf "${FORMAT_STR_GREY}%s\n" "BENCHMARKING: $model"
  printf "${FORMAT_STR}%s" "DT:"
  printf "${FORMAT_STR_GREEN}%s\n" "${dt_on}"
  printf "${FORMAT_STR}%s" "Distribution:"
  printf "${FORMAT_STR_GREEN}%s\n" "${distribution}"
  if [[ $distribution == "false" ]]; then
    local -r device="local-sync"
    local -r disable_distr="true"
  else
    local -r device="local-task"
    local -r disable_distr="false"
  fi

  "$IREE_BUILD_DIR"/tools/iree-compile \
    --iree-llvmcpu-disable-distribution="${disable_distr}" \
    --iree-global-opt-data-tiling="${dt_on}" \
    --iree-hal-target-backends=llvm-cpu \
    --iree-llvmcpu-target-cpu=generic \
    --iree-llvmcpu-vector-pproc-strategy=peel \
    --iree-llvmcpu-enable-ukernels=none \
    --iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-convert-conv2d-to-img2col))' \
    "$SVE_MODELS_DIR/${model}.mlir" \
    -o "${model}.vmfb"
  "$IREE_BUILD_DIR"/tools/iree-benchmark-module \
    --device="${device}" \
    --module="${model}.vmfb" \
    --function=main \
    ${pipeline_input}
}
# === main =========================================================
#
# Entry point for this script. Iterates over all models in sve-models and
# benchmarks all combinations of settings for data-tiling + distribution.
# ==================================================================
main()
{
  local model distribution dt pipeline_input

  # VIT
  model=vit-base-patch16-224
  pipeline_input=$(cat <<END
--input="1x3x224x224xf32=1.0"
END
)

  for dt in false true; do
    for distribution in false true; do
      build_and_run "$model" "$distribution" "$dt" "$pipeline_input"
    done
  done

  # SmolLM
  model=SmolLM135M-F32
  pipeline_input=$(cat <<END
--input="1x145xi32=2.0"
--input="1x145xi32=4.0"
END
)

  for dt in false true; do
    for distribution in false true; do
      build_and_run "$model" "$distribution" "$dt" "$pipeline_input"
    done
  done
}
main "$@"
"""pytorch_benchmark_model - Benchmark PyTorch models.
Follows guide at: https://pytorch.org/tutorials/recipes/recipes/benchmark.html
"""
import argparse
import torch
from torch.utils import benchmark
from torch import nn
from transformers import AutoConfig
from transformers import AutoModelForImageClassification
from transformers import AutoImageProcessor
from executorch.exir import to_edge_transform_and_lower
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
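
# Fixed seed so the randomly-initialized weights are reproducible
# (from_config below builds the model without loading pretrained weights).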
torch.manual_seed(11)
#------------------------------------------------------------------------------
# Wrapper for VIT
#------------------------------------------------------------------------------
class VITWrapper(nn.Module):
    def __init__(self, img_size: tuple[int, int] = (224, 224)):
        super().__init__()
        model_name = "google/vit-base-patch16-224"
        cfg = AutoConfig.from_pretrained(model_name)
        self.model = AutoModelForImageClassification.from_config(cfg)
        self.model.eval()
        self.processor = AutoImageProcessor.from_pretrained(model_name)
        self.img_size = img_size

        # Compile with TorchInductor
        # backend="inductor" is default, but we can be explicit
        self.compiled_model = torch.compile(
            self.model,
            backend="inductor",  # explicit inductor
            mode="default",      # or "max-autotune", "reduce-overhead", etc.
        )
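        # NOTE: forward() below uses self.model (eager), and VITModelInductor
        # compiles its own copy, so this compiled_model is not what the
        # benchmarks measure.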

        # --- generate a fixed example input once ---
        torch.manual_seed(123)  # choose any constant
        img = torch.randn(1, 3, *self.img_size).clamp_(0, 1)
        self._example_processed = self.processor(images=img, return_tensors="pt")

    def example_inputs(self) -> dict[str, torch.Tensor]:
        # return clones so callers don't mutate the cached tensors
        return {k: v.clone() for k, v in self._example_processed.items()}

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        logits = self.model(pixel_values).logits
        res = torch.argmax(logits, dim=-1)[0]
        return res


class VITModelEager:
    def __init__(self, VW):
        super().__init__()
        self.model = VW.model
        self.inputs = VW.example_inputs()
        # print(self.inputs)

    def run(self):
        with torch.no_grad():
            logits = self.model(self.inputs["pixel_values"]).logits
            res = torch.argmax(logits[0], dim=-1)
            # print(f"EAGER: {res}")


class VITModelInductor:
    def __init__(self, VW):
        super().__init__()
        eager = VW.model
        self.inputs = VW.example_inputs()
        # print(self.inputs)

        # compile with TorchInductor
        compiled = torch.compile(
            eager,
            backend="inductor",  # explicit (default on CPU/GPU anyway)
            mode="default",      # or "max-autotune", etc.
        )
        self.model = compiled

        # optional warmup so compile cost doesn't pollute timing
        with torch.no_grad():
            for _ in range(3):
                self.model(self.inputs["pixel_values"])

    def run(self):
        with torch.no_grad():
            logits = self.model(self.inputs["pixel_values"]).logits
            res = torch.argmax(logits, dim=-1)[0]
            # print(f"IND: {res}")


class VITModelExecuTorch:
    def __init__(self, VW):
        super().__init__()
        eager = VW.model
        self.inputs = VW.example_inputs()
        # print(self.inputs)

        # 1) Export to PyTorch Export IR
        exported_program = torch.export.export(
            eager,
            (self.inputs["pixel_values"],),
            strict=True,
        )

        # 2) Convert to ExecuTorch edge program
        # (XnnpackPartitioner targets CPU; use CoreMLPartitioner for iOS or
        # QnnPartitioner for Qualcomm)
        program = to_edge_transform_and_lower(
            exported_program,
            partitioner=[XnnpackPartitioner()],
        ).to_executorch()

        # 3) Save for deployment
        with open("model.pte", "wb") as f:
            f.write(program.buffer)

        # Test locally via ExecuTorch runtime's pybind API (optional)
        from executorch.runtime import Runtime
        runtime = Runtime.get()
        self.method = runtime.load_program("model.pte").load_method("forward")
        for _ in range(3):
            self.run()

    def run(self):
        logits = self.method.execute([self.inputs["pixel_values"]])[0]
        res = torch.argmax(logits, dim=-1)[0]
        # print(f"ET: {res}")


#------------------------------------------------------------------------------
# main
#------------------------------------------------------------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num-threads",
        type=int,
        default=0,
        help="Number of threads to use for benchmarking",
    )
    args = parser.parse_args()
    num_threads = args.num_threads if args.num_threads > 0 else torch.get_num_threads()

    vw = VITWrapper()
    eager_model = VITModelEager(vw)
    inductor_model = VITModelInductor(vw)
    executorch_model = VITModelExecuTorch(vw)

    timers = []
    t_eager = benchmark.Timer(
        stmt="model.run()",
        globals={"model": eager_model},
        num_threads=num_threads,
        label="VIT",
        sub_label="eager",
        description="eager",
    )
    timers.append(t_eager.blocked_autorange(min_run_time=2.0))

    t_inductor = benchmark.Timer(
        stmt="model.run()",
        globals={"model": inductor_model},
        num_threads=num_threads,
        label="VIT",
        sub_label="inductor",
        description="torch.compile(inductor)",
    )
    timers.append(t_inductor.blocked_autorange(min_run_time=2.0))

    t_executorch = benchmark.Timer(
        stmt="model.run()",
        globals={"model": executorch_model},
        num_threads=num_threads,
        label="VIT",
        sub_label="executorch",
        description="ExecuTorch(XNNPACK)",
    )
    timers.append(t_executorch.blocked_autorange(min_run_time=2.0))

    # Pretty compare
    compare = benchmark.Compare(timers)
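    # Optional, from the same PyTorch benchmark recipe: trim to consistent
    # significant figures and colorize the table before printing.
    # compare.trim_significant_figures()
    # compare.colorize()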
    compare.print()