Last active
December 10, 2025 10:17
-
-
Save banach-space/71e2b82dd9eff13e0cbcd1ab68a0ce31 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #! /usr/bin/env bash | |
| # === benchmark_iree_on_arm.sh ================================================ | |
| # | |
| # Benchmarks models from https://huggingface.co/mlir-for-sve/sve-models/. | |
| # | |
| # NOTE - This script will not clone the models repo! | |
| # NOTE - Set IREE_BUILD_DIR and SVE_MODELS_DIR accordingly! | |
| # | |
| # ============================================================================= | |
| set -euo pipefail | |
| # SET THESE UP! | |
| readonly IREE_BUILD_DIR=$HOME/iree-build/Release/ | |
| readonly SVE_MODELS_DIR=$HOME/sve-models | |
| FORMAT_STR_GREY='\e[1m%-6s\e[0m' | |
| FORMAT_STR_GREEN='\e[1;32m%-6s\e[0m' | |
| FORMAT_STR='\e[1;37m%-6s\e[0m' | |
| # === build_and_run ============================================================ | |
| # | |
| # Compiles and benchmarks a model | |
| # | |
| # $1 - Model to benchmark (no suffix) | |
| # $2 - Whether to enable distribution | |
| # $3 - Whether to enable data-tiling | |
| # $4 - Input arguments for iree-benchmark-module | |
| # ============================================================================= | |
| build_and_run() | |
| { | |
| local -r model="$1" | |
| local -r distribution="$2" | |
| local -r dt_on="$3" | |
| local -r pipeline_input="$4" | |
| printf "${FORMAT_STR_GREY}%s\n" "BENCHMARKING: $model" | |
| printf "${FORMAT_STR}%s" "DT:" | |
| printf "${FORMAT_STR_GREEN}%s\n" "${dt_on}" | |
| printf "${FORMAT_STR}%s" "Distribution:" | |
| printf "${FORMAT_STR_GREEN}%s\n" "${distribution}" | |
| if [[ $distribution == "false" ]]; then | |
| local -r device="local-sync" | |
| local -r disable_distr="true" | |
| else | |
| local -r device="local-task" | |
| local -r disable_distr="false" | |
| fi | |
| $IREE_BUILD_DIR/tools/iree-compile \ | |
| --iree-llvmcpu-disable-distribution="${disable_distr}" \ | |
| --iree-global-opt-data-tiling="${dt_on}"\ | |
| --iree-hal-target-backends=llvm-cpu\ | |
| --iree-llvmcpu-target-cpu=generic \ | |
| --iree-llvmcpu-vector-pproc-strategy=peel \ | |
| --iree-llvmcpu-enable-ukernels=none\ | |
| --iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-convert-conv2d-to-img2col))' \ | |
| $SVE_MODELS_DIR/${model}.mlir \ | |
| -o "${model}.vmfb" | |
| $IREE_BUILD_DIR/tools/iree-benchmark-module\ | |
| --device="${device}"\ | |
| --module="${model}".vmfb\ | |
| --function=main\ | |
| ${pipeline_input} | |
| } | |
| # === main ========================================================= | |
| # | |
| # Entry point for this script. Iterates overall all models in sve-models and | |
| # benchmarks all combinations of settings for data-tiling + distribution. | |
| # ================================================================== | |
| main() | |
| { | |
| local model distribution dt pipeline_input | |
| # VIT | |
| model=vit-base-patch16-224 | |
| pipeline_input=$(cat <<-END | |
| --input="1x3x224x224xf32=1.0" | |
| END | |
| ) | |
| distribution=false | |
| dt=false | |
| build_and_run $model $distribution $dt "$pipeline_input" | |
| distribution=true | |
| dt=false | |
| build_and_run $model $distribution $dt "$pipeline_input" | |
| distribution=false | |
| dt=true | |
| build_and_run $model $distribution $dt "$pipeline_input" | |
| distribution=true | |
| dt=true | |
| build_and_run $model $distribution $dt "$pipeline_input" | |
| # Smoll | |
| model=SmolLM135M-F32 | |
| pipeline_input=$(cat <<-END | |
| --input="1x145xi32=2.0"\ | |
| --input="1x145xi32=4.0" | |
| END | |
| ) | |
| distribution=false | |
| dt=false | |
| build_and_run $model $distribution $dt "$pipeline_input" | |
| distribution=true | |
| dt=false | |
| build_and_run $model $distribution $dt "$pipeline_input" | |
| distribution=false | |
| dt=true | |
| build_and_run $model $distribution $dt "$pipeline_input" | |
| distribution=true | |
| dt=true | |
| build_and_run $model $distribution $dt "$pipeline_input" | |
| } | |
| main "$@" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """pytorch_benchmark_model - Benchmark PyTorch models. | |
| Follows guide at: https://pytorch.org/tutorials/recipes/recipes/benchmark.html | |
| """ | |
| from typing import Any | |
| import argparse | |
| import torch | |
| from torch.utils import benchmark | |
| from torch import nn | |
| from transformers import AutoConfig | |
| from transformers import AutoModelForImageClassification | |
| from transformers import AutoImageProcessor | |
| torch.manual_seed(11) | |
| #------------------------------------------------------------------------------ | |
| # Wrapper for VIT | |
| #------------------------------------------------------------------------------ | |
| class VITWrapper(nn.Module): | |
| def __init__(self, img_size: tuple[int, int] = (224, 224)): | |
| super().__init__() | |
| model_name = "google/vit-base-patch16-224" | |
| cfg = AutoConfig.from_pretrained(model_name) | |
| self.model = AutoModelForImageClassification.from_config(cfg) | |
| self.model.eval() | |
| self.processor = AutoImageProcessor.from_pretrained(model_name) | |
| self.img_size = img_size | |
| # Compile with TorchInductor | |
| # backend="inductor" is default, but we can be explicit | |
| self.compiled_model = torch.compile( | |
| self.model, | |
| backend="inductor", # explicit inductor | |
| mode="default", # or "max-autotune", "reduce-overhead", etc. | |
| ) | |
| # --- generate a fixed example input once --- | |
| torch.manual_seed(123) # choose any constant | |
| img = torch.randn(1, 3, *self.img_size).clamp_(0, 1) | |
| self._example_processed = self.processor(images=img, return_tensors="pt") | |
| def example_inputs(self) -> dict[str, torch.Tensor]: | |
| # return clones so callers don't mutate the cached tensors | |
| return {k: v.clone() for k, v in self._example_processed.items()} | |
| def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: | |
| logits = self.model(pixel_values).logits | |
| res = torch.argmax(logits, dim=-1)[0] | |
| return res | |
| class VITModelEager: | |
| def __init__(self, VW): | |
| super().__init__() | |
| self.model = VW.model | |
| self.inputs = VW.example_inputs() | |
| # print(self.inputs) | |
| def run(self): | |
| with torch.no_grad(): | |
| logits = self.model(self.inputs["pixel_values"]).logits | |
| res = torch.argmax(logits[0], dim=-1) | |
| # print(f"EAGER: {res}") | |
| class VITModelInductor: | |
| def __init__(self, VW): | |
| super().__init__() | |
| eager = VW.model | |
| self.inputs = VW.example_inputs() | |
| # print(self.inputs) | |
| # compile with TorchInductor | |
| compiled = torch.compile( | |
| eager, | |
| backend="inductor", # explicit (default on CPU/GPU anyway) | |
| mode="default", # or "max-autotune", etc. | |
| ) | |
| self.model = compiled | |
| # optional warmup so compile cost doesn’t pollute timing | |
| with torch.no_grad(): | |
| for _ in range(3): | |
| self.model(self.inputs["pixel_values"]) | |
| def run(self): | |
| with torch.no_grad(): | |
| logits = self.model(self.inputs["pixel_values"]).logits | |
| res = torch.argmax(logits, dim=-1)[0] | |
| # print(f"IND: {res}") | |
| import torch | |
| from executorch.exir import to_edge_transform_and_lower | |
| from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner | |
| class VITModelExecuTorch: | |
| def __init__(self, VW): | |
| super().__init__() | |
| eager = VW.model | |
| self.inputs = VW.example_inputs() | |
| # print(self.inputs) | |
| # 1) Export to PyTorch Export IR | |
| exported_program = torch.export.export( | |
| eager, | |
| (self.inputs["pixel_values"],), | |
| strict=True, | |
| ) | |
| # 2) Convert to ExecuTorch edge program | |
| program = to_edge_transform_and_lower( | |
| exported_program, | |
| partitioner=[XnnpackPartitioner()] # CPU | CoreMLPartitioner() for iOS | QnnPartitioner() for Qualcomm | |
| ).to_executorch() | |
| # 3. Save for deployment | |
| with open("model.pte", "wb") as f: | |
| f.write(program.buffer) | |
| # Test locally via ExecuTorch runtime's pybind API (optional) | |
| from executorch.runtime import Runtime | |
| runtime = Runtime.get() | |
| self.method = runtime.load_program("model.pte").load_method("forward") | |
| for _ in range(3): | |
| self.run() | |
| def run(self): | |
| logits = self.method.execute([self.inputs["pixel_values"]])[0] | |
| res = torch.argmax(logits, dim=-1)[0] | |
| # print(f"ET: {res}") | |
| #------------------------------------------------------------------------------ | |
| # main | |
| #------------------------------------------------------------------------------ | |
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument( | |
| "--num-threads", | |
| type=int, | |
| default=0, | |
| help="Number of threads to use for benchmarking", | |
| ) | |
| args = parser.parse_args() | |
| num_threads = args.num_threads if args.num_threads > 0 else torch.get_num_threads() | |
| vw = VITWrapper() | |
| eager_model = VITModelEager(vw) | |
| inductor_model = VITModelInductor(vw) | |
| executorch_model = VITModelExecuTorch(vw) | |
| timers = [] | |
| t_eager = benchmark.Timer( | |
| stmt="model.run()", | |
| globals={"model": eager_model}, | |
| num_threads=num_threads, | |
| label="VIT", | |
| sub_label="eager", | |
| description="eager", | |
| ) | |
| timers.append(t_eager.blocked_autorange(min_run_time=2.0)) | |
| t_inductor = benchmark.Timer( | |
| stmt="model.run()", | |
| globals={"model": inductor_model}, | |
| num_threads=num_threads, | |
| label="VIT", | |
| sub_label="inductor", | |
| description="torch.compile(inductor)", | |
| ) | |
| timers.append(t_inductor.blocked_autorange(min_run_time=2.0)) | |
| t_executorch = benchmark.Timer( | |
| stmt="model.run()", | |
| globals={"model": executorch_model}, | |
| num_threads=num_threads, | |
| label="VIT", | |
| sub_label="executorch", | |
| description="ExecuTorch(XNNPACK)", | |
| ) | |
| timers.append(t_executorch.blocked_autorange(min_run_time=2.0)) | |
| # Pretty compare | |
| compare = benchmark.Compare(timers) | |
| compare.print() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment