Last active
October 14, 2025 04:33
-
-
Save justinchuby/b817e87d8f4c9ae701ece75f2792b6de to your computer and use it in GitHub Desktop.
Export HF model to ONNX
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """Export to ONNX. | |
| transformers_version == "4.52.0" | |
| """ | |
| import onnx_diagnostic.tasks.text_generation | |
| import torch | |
| from transformers import AutoConfig, AutoModel | |
| import onnxscript | |
| import onnxscript.rewriter.onnx_fusions | |
# Checkpoint to export. The alternatives below are other Gemma variants that
# were tried; uncomment exactly one MODEL_ID line to switch.
# MODEL_ID = "google/gemma-2b"
# MODEL_ID = "google/gemma-3-270m-it"
# MODEL_ID = "google/gemma-3-270m"
MODEL_ID = "google/gemma-3-4b-it"
def get_hf_model(model_id: str):
    """Load a Hugging Face model and its config for export.

    Args:
        model_id: Hub id of the checkpoint, e.g. ``"google/gemma-3-4b-it"``.

    Returns:
        A ``(model, config)`` tuple where the model uses SDPA attention and
        is in eval mode.
    """
    config = AutoConfig.from_pretrained(model_id, attn_implementation="sdpa")
    model = AutoModel.from_pretrained(model_id, config=config)
    # This call is important: some models produce different outputs for the
    # same inputs in training mode. from_pretrained already returns an eval
    # model, but calling eval() explicitly keeps that guarantee if this is
    # ever refactored to build the model another way (e.g. from_config on a
    # meta device).
    model.eval()
    return model, config
def get_model_input(
    model: torch.nn.Module,
    config,
    *,
    dummy_max_token_id: int = 10,
    num_hidden_layers: int = 34,
    num_key_value_heads: int = 4,
    head_dim: int = 256,
    cls_cache: str = "DynamicCache",
):
    """Build example inputs and dynamic shapes for text-generation export.

    The keyword defaults match google/gemma-3-4b-it; pass different values
    when exporting another architecture instead of editing this function.

    Args:
        model: The model being exported.
        config: Its Hugging Face config object.
        dummy_max_token_id: Upper bound for dummy input token ids.
        num_hidden_layers: Number of decoder layers (sizes the KV cache).
        num_key_value_heads: KV heads per layer (sizes the KV cache).
        head_dim: Per-head hidden dimension (sizes the KV cache).
        cls_cache: Cache class name understood by onnx_diagnostic.

    Returns:
        A ``(inputs, dynamic_shapes)`` tuple suitable for
        ``torch.onnx.export(kwargs=inputs, dynamic_shapes=dynamic_shapes)``.
    """
    result = onnx_diagnostic.tasks.text_generation.get_inputs(
        model,
        config,
        dummy_max_token_id=dummy_max_token_id,
        num_hidden_layers=num_hidden_layers,
        num_key_value_heads=num_key_value_heads,
        head_dim=head_dim,
        cls_cache=cls_cache,
    )
    return result["inputs"], result["dynamic_shapes"]
# Load the checkpoint and build exporter inputs.
model, config = get_hf_model(MODEL_ID)
example_kwargs, dynamic_shapes = get_model_input(model, config)

# Export via the dynamo-based exporter; positional args are empty because
# everything is passed as keywords.
export_options = dict(
    kwargs=example_kwargs,
    dynamic_shapes=dynamic_shapes,
    opset_version=23,
    dynamo=True,
    # profile=True,
)
onnx_program = torch.onnx.export(model, (), **export_options)
onnx_program.save("gemma3_4b_it.onnx")
# onnxscript.rewriter.onnx_fusions.fuse(onnx_program.model) — not needed here;
# this fusion pass runs as part of the optimize=True option.
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
transformers==4.55.0