Export HF model to ONNX
| """Export to ONNX. | |
| transformers_version == "4.52.0" | |
| """ | |
| import onnx_diagnostic.tasks.text_generation | |
| import torch | |
| from transformers import AutoConfig, AutoModel | |
| import onnxscript | |
| import onnxscript.rewriter.onnx_fusions | |
| # MODEL_ID = "google/gemma-2b" | |
| # MODEL_ID = "google/gemma-3-270m-it" | |
| # MODEL_ID = "google/gemma-3-270m" | |
| MODEL_ID = "google/gemma-3-4b-it" | |
| def get_hf_model(model_id: str): | |
| config = AutoConfig.from_pretrained(model_id, attn_implementation="sdpa") | |
| # with torch.device('meta'): | |
| # model = AutoModel.from_config(config) | |
| # # This line is important. Some models may produce different | |
| # # outputs even with the same inputs in training mode. | |
| # model.eval() | |
| model = AutoModel.from_pretrained(model_id, config=config) | |
| return model, config | |
| def get_model_input(model: torch.nn.Module, config): | |
| # with torch.device('meta'): | |
| result = onnx_diagnostic.tasks.text_generation.get_inputs( | |
| model, | |
| config, | |
| dummy_max_token_id=10, | |
| num_hidden_layers=34, | |
| num_key_value_heads=4, | |
| head_dim=256, | |
| cls_cache="DynamicCache" | |
| ) | |
| return result["inputs"], result["dynamic_shapes"] | |
| model, config = get_hf_model(MODEL_ID) | |
| example_kwargs, dynamic_shapes = get_model_input(model, config) | |
| # ONNX Export | |
| onnx_program = torch.onnx.export( | |
| model, | |
| (), | |
| kwargs=example_kwargs, | |
| dynamic_shapes=dynamic_shapes, | |
| opset_version=23, | |
| dynamo=True, | |
| # profile=True, | |
| ) | |
| onnx_program.save("gemma3_4b_it.onnx") | |
| # onnxscript.rewriter.onnx_fusions.fuse(onnx_program.model) # This is performed as part of the optimize=True option |
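A quick sanity check (not in the original gist) is to run the ONNXProgram through ONNX Runtime on the same example inputs and compare the first output against eager PyTorch. This is a minimal sketch, assuming onnxruntime is installed, that the DynamicCache inputs flatten the same way they did at export time, and that the first ONNX output corresponds to last_hidden_state; the tolerances are illustrative only.

# Optional sanity check (sketch): compare ONNX Runtime against eager PyTorch.
# Assumes onnxruntime is installed; tolerances are illustrative.
onnx_outputs = onnx_program(**example_kwargs)  # runs the exported model via ONNX Runtime
eager_outputs = model(**example_kwargs)
torch.testing.assert_close(
    onnx_outputs[0],
    eager_outputs.last_hidden_state,
    rtol=1e-3,
    atol=1e-3,
)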
Author
"""Export to ONNX.
transformers_version == "4.52.0"
"""
import onnx_diagnostic.tasks.text_generation
import torch
from transformers import AutoConfig, AutoModel
from torch._subclasses.fake_tensor import FakeTensorMode
import onnxscript
import onnxscript.optimizer  # needed for onnxscript.optimizer.optimize below
import onnxscript.rewriter.onnx_fusions
# MODEL_ID = "google/gemma-2b"
# MODEL_ID = "google/gemma-3-270m-it"
# MODEL_ID = "google/gemma-3-270m"
MODEL_ID = "google/gemma-3-4b-it"
def get_hf_model(model_id: str):
config = AutoConfig.from_pretrained(model_id, attn_implementation="sdpa")
# with torch.device('meta'):
# model = AutoModel.from_config(config)
# # This line is important. Some models may produce different
# # outputs even with the same inputs in training mode.
# model.eval()
model = AutoModel.from_pretrained(model_id, config=config)
return model, config
def get_model_input(model: torch.nn.Module, config):
# with torch.device('meta'):
result = onnx_diagnostic.tasks.text_generation.get_inputs(
model,
config,
dummy_max_token_id=10,
num_hidden_layers=34,
num_key_value_heads=4,
head_dim=256,
cls_cache="DynamicCache"
)
return result["inputs"], result["dynamic_shapes"]
# with FakeTensorMode():
model, config = get_hf_model(MODEL_ID)
example_kwargs, dynamic_shapes = get_model_input(model, config)
# ONNX Export
onnx_program = torch.onnx.export(
model,
(),
kwargs=example_kwargs,
dynamic_shapes=dynamic_shapes,
opset_version=23,
dynamo=True,
report=True,
optimize=False,
# profile=True,
)
print("Saving ONNX model...")
onnx_program.save("gemma3_4b_it_original.onnx", include_initializers=False, keep_initializers_as_inputs=True)
print("Optimizing ONNX model...")
onnx_program.optimize()
print("Saving optimized ONNX model...")
onnx_program.save("gemma3_4b_it_optimized.onnx", include_initializers=False, keep_initializers_as_inputs=True)
print("Fusing ONNX model...")
onnx_program.model.functions.clear()
onnxscript.optimizer.optimize(onnx_program.model, num_iterations=2, input_size_limit=10000)
onnxscript.rewriter.onnx_fusions.fuse(onnx_program.model)
onnxscript.optimizer.optimize(onnx_program.model, num_iterations=2, input_size_limit=10000)
print("Saving fused ONNX model...")
onnx_program.save("gemma3_4b_it_fused.onnx", include_initializers=False, keep_initializers_as_inputs=True)
Author
transformers==4.56.0
import torch
from transformers import AutoConfig, AutoModel
import transformers
# MODEL_ID = "google/gemma-2b"
# MODEL_ID = "google/gemma-3-270m-it"
# MODEL_ID = "google/gemma-3-270m"
MODEL_ID = "google/gemma-3-4b-it"
def get_hf_model(model_id: str):
config = AutoConfig.from_pretrained(model_id, attn_implementation="sdpa")
# with torch.device('meta'):
# model = AutoModel.from_config(config)
# # This line is important. Some models may produce different
# # outputs even with the same inputs in training mode.
# model.eval()
model = AutoModel.from_pretrained(model_id, config=config)
model = TextGenerationModelWrapper(model)
return model, config
def create_text_gen_example_inputs(
config, batch_size: int = 2, seq_len: int = 3, past_seq_len: int = 2
):
"""Create example inputs and dynamic axes for ONNX export."""
config = config.text_config
num_hidden_layers = config.num_hidden_layers
batch = "batch"
sequence_len = "sequence_len"
past_sequence_len = "past_sequence_len"
dynamic_shapes = {
"input_ids": {0: batch, 1: sequence_len},
"attention_mask": {
0: batch,
1: "past_sequence_len+sequence_len",
},
"position_ids": {
0: batch,
1: sequence_len,
},
"past_key_values": [
({0: batch, 2: past_sequence_len}, {0: batch, 2: past_sequence_len})
for _ in range(num_hidden_layers)
],
}
input_names = [
"input_ids",
"attention_mask",
"position_ids",
*[
name
for i in range(num_hidden_layers)
for name in (f"past_key_values.{i}.key", f"past_key_values.{i}.value")
],
]
output_names = [
"logits",
*[
name
for i in range(num_hidden_layers)
for name in (f"present.{i}.key", f"present.{i}.value")
],
]
num_key_value_heads = config.num_key_value_heads
head_dim = config.head_dim
example_inputs = dict(
input_ids=torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int64),
attention_mask=torch.ones(
(batch_size, past_seq_len + seq_len),
dtype=torch.int64,
),
position_ids=torch.arange(
past_seq_len,
past_seq_len + seq_len,
dtype=torch.int64,
).expand((batch_size, -1)),
past_key_values=[
(
torch.randn(
batch_size,
num_key_value_heads,
                    past_seq_len,
head_dim,
),
torch.randn(
batch_size,
num_key_value_heads,
                    past_seq_len,
head_dim,
),
)
for _ in range(num_hidden_layers)
],
)
return example_inputs, dynamic_shapes, input_names, output_names
def make_dynamic_cache(
key_value_pairs: list[tuple[torch.Tensor, torch.Tensor]],
) -> transformers.cache_utils.DynamicCache:
    cache = transformers.cache_utils.DynamicCache()  # fill via update() below; passing the pairs here too would populate each layer twice
for i, (key, value) in enumerate(key_value_pairs):
cache.update(key, value, i)
return cache
class TextGenerationModelWrapper(torch.nn.Module):
def __init__(self, model: torch.nn.Module):
super().__init__()
self.model = model
def forward(
self,
input_ids,
attention_mask,
position_ids,
past_key_values: list[tuple[torch.Tensor, torch.Tensor]],
):
hf_output = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=make_dynamic_cache(past_key_values),
)
return hf_output.last_hidden_state, hf_output.hidden_states
model, config = get_hf_model(MODEL_ID)
example_kwargs, dynamic_shapes, input_names, output_names = (
create_text_gen_example_inputs(config)
)
# ONNX Export
with torch._dynamo.config.patch(fake_tensor_cache_enabled=False):
onnx_program = torch.onnx.export(
model,
(),
kwargs=example_kwargs,
input_names=input_names,
output_names=output_names,
dynamic_shapes=dynamic_shapes,
opset_version=23,
dynamo=True,
report=True,
)
onnx_program.save("gemma3_4b_it.onnx", external_data=True)
onnx_output = onnx_program.call_reference(**example_kwargs)
print(onnx_output)
print("------------")
output = model(**example_kwargs)
print(output)
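Because this version pins explicit input_names for the flattened KV cache, the saved file can also be driven straight from ONNX Runtime by building a feed dict with those names. A rough sketch, assuming onnxruntime is installed and there is enough memory for the fp32 weights; output ordering follows output_names above:

# Run the saved model in ONNX Runtime using the flattened I/O names.
# Assumes onnxruntime is installed; feed names mirror input_names above.
import onnxruntime as ort

session = ort.InferenceSession("gemma3_4b_it.onnx")
feeds = {
    "input_ids": example_kwargs["input_ids"].numpy(),
    "attention_mask": example_kwargs["attention_mask"].numpy(),
    "position_ids": example_kwargs["position_ids"].numpy(),
}
for i, (key, value) in enumerate(example_kwargs["past_key_values"]):
    feeds[f"past_key_values.{i}.key"] = key.numpy()
    feeds[f"past_key_values.{i}.value"] = value.numpy()
outputs = session.run(None, feeds)
print(outputs[0].shape)  # first output, named "logits" in output_names above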
Author
import torch
from transformers import AutoConfig, AutoModel
import transformers
# MODEL_ID = "google/gemma-2b"
# MODEL_ID = "google/gemma-3-270m-it"
# MODEL_ID = "google/gemma-3-270m"
MODEL_ID = "google/gemma-3-4b-it"
def get_hf_model(model_id: str):
config = AutoConfig.from_pretrained(model_id, attn_implementation="sdpa")
# with torch.device('meta'):
# model = AutoModel.from_config(config)
# # This line is important. Some models may produce different
# # outputs even with the same inputs in training mode.
# model.eval()
model = AutoModel.from_pretrained(model_id, config=config)
model = TextGenerationModelWrapper(model)
return model, config
def create_text_gen_example_inputs(
config, batch_size: int = 2, seq_len: int = 3, past_seq_len: int = 2
):
"""Create example inputs and dynamic axes for ONNX export."""
config = config.text_config
num_hidden_layers = config.num_hidden_layers
# batch = "batch"
# sequence_len = "sequence_len"
# past_sequence_len = "past_sequence_len"
batch = torch.export.Dim("batch")
sequence_len = torch.export.Dim("sequence_len")
past_sequence_len = torch.export.Dim("past_sequence_len")
dynamic_shapes = {
"input_ids": {0: batch, 1: sequence_len},
"attention_mask": {
0: batch,
1: "past_sequence_len+sequence_len",
},
"position_ids": {
0: batch,
1: sequence_len,
},
"past_key_values": [
({0: batch, 2: past_sequence_len}, {0: batch, 2: past_sequence_len})
for _ in range(num_hidden_layers)
],
}
input_names = [
"input_ids",
"attention_mask",
"position_ids",
*[
name
for i in range(num_hidden_layers)
for name in (f"past_key_values.{i}.key", f"past_key_values.{i}.value")
],
]
output_names = [
"logits",
*[
name
for i in range(num_hidden_layers)
for name in (f"present.{i}.key", f"present.{i}.value")
],
]
num_key_value_heads = config.num_key_value_heads
head_dim = config.head_dim
example_inputs = dict(
input_ids=torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int64),
attention_mask=torch.ones(
(batch_size, past_seq_len + seq_len),
dtype=torch.int64,
),
position_ids=torch.arange(
past_seq_len,
past_seq_len + seq_len,
dtype=torch.int64,
).expand((batch_size, -1)),
past_key_values=[
(
torch.randn(
batch_size,
num_key_value_heads,
                    past_seq_len,
head_dim,
),
torch.randn(
batch_size,
num_key_value_heads,
                    past_seq_len,
head_dim,
),
)
for _ in range(num_hidden_layers)
],
)
example_inputs["past_key_values"] = make_dynamic_cache(
example_inputs["past_key_values"]
)
return example_inputs, dynamic_shapes, input_names, output_names
def make_dynamic_cache(
key_value_pairs: list[tuple[torch.Tensor, torch.Tensor]],
) -> transformers.cache_utils.DynamicCache:
cache = transformers.cache_utils.DynamicCache()
for i, (key, value) in enumerate(key_value_pairs):
cache.update(key, value, i)
return cache
class TextGenerationModelWrapper(torch.nn.Module):
def __init__(self, model: torch.nn.Module):
super().__init__()
self.model = model
self.dynamic_cache = None
def forward(
self,
input_ids,
attention_mask,
position_ids,
past_key_values,
):
hf_output = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
)
return hf_output.last_hidden_state, hf_output.hidden_states
model, config = get_hf_model(MODEL_ID)
example_kwargs, dynamic_shapes, input_names, output_names = (
create_text_gen_example_inputs(config)
)
# ONNX Export
with torch._dynamo.config.patch(fake_tensor_cache_enabled=False):
onnx_program = torch.onnx.export(
model,
(),
kwargs=example_kwargs,
input_names=input_names,
output_names=output_names,
dynamic_shapes=dynamic_shapes,
opset_version=23,
dynamo=True,
report=True,
)
onnx_program.save("gemma3_4b_it.onnx", external_data=True)
onnx_output = onnx_program(**example_kwargs)
print(onnx_output)
print("------------")
output = model(**example_kwargs)
print(output)
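If the exporter has trouble flattening the DynamicCache example input above, one workaround used further down in this thread is to register the cache class with torch.export's pytree machinery before exporting. A minimal sketch using the helper that ships with transformers:

# Make DynamicCache flattenable by torch.export before calling torch.onnx.export.
# Uses the helper from transformers.integrations.executorch (also used in a later comment).
import transformers

transformers.integrations.executorch.register_dynamic_cache_export_support()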
Author
Giving up:
import torch
from transformers import AutoConfig, AutoModel, GenerationConfig
import transformers
# MODEL_ID = "google/gemma-2b"
# MODEL_ID = "google/gemma-3-270m-it"
# MODEL_ID = "google/gemma-3-270m"
MODEL_ID = "google/gemma-3-4b-it"
def get_hf_model(model_id: str):
config = AutoConfig.from_pretrained(model_id, attn_implementation="sdpa")
# with torch.device('meta'):
# model = AutoModel.from_config(config)
# # This line is important. Some models may produce different
# # outputs even with the same inputs in training mode.
# model.eval()
model = AutoModel.from_pretrained(model_id, config=config)
generation_config = GenerationConfig.from_pretrained(
model_id, cache_implementation="static"
)
model.generation_config = generation_config
# model = TextGenerationModelWrapper(model)
return model, config
def create_text_gen_example_inputs(
config, batch_size: int = 2, seq_len: int = 8, past_seq_len: int = 2
):
"""Create example inputs and dynamic axes for ONNX export."""
config = config.text_config
num_hidden_layers = config.num_hidden_layers
# batch = "batch"
# sequence_len = "sequence_len"
# past_sequence_len = "past_sequence_len"
batch = torch.export.Dim("batch")
sequence_len = torch.export.Dim("sequence_len")
past_sequence_len = torch.export.Dim("past_sequence_len")
# dynamic_shapes = {
# "input_ids": {0: batch, 1: sequence_len},
# "attention_mask": {
# 0: batch,
# 1: "past_sequence_len+sequence_len",
# },
# "position_ids": {
# 0: batch,
# 1: sequence_len,
# },
# "past_key_values": [
# ({0: batch, 2: past_sequence_len}, {0: batch, 2: past_sequence_len})
# for _ in range(num_hidden_layers)
# ],
# }
dynamic_shapes = {
"input_ids": {0: batch, 1: sequence_len},
"cache_position": {
0: batch,
1: sequence_len,
},
}
input_names = [
"input_ids",
"attention_mask",
"position_ids",
*[
name
for i in range(num_hidden_layers)
for name in (f"past_key_values.{i}.key", f"past_key_values.{i}.value")
],
]
output_names = [
"logits",
*[
name
for i in range(num_hidden_layers)
for name in (f"present.{i}.key", f"present.{i}.value")
],
]
num_key_value_heads = config.num_key_value_heads
head_dim = config.head_dim
example_inputs = dict(
input_ids=torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int64),
attention_mask=torch.ones(
(batch_size, past_seq_len + seq_len),
dtype=torch.int64,
),
position_ids=torch.arange(
past_seq_len,
past_seq_len + seq_len,
dtype=torch.int64,
).expand((batch_size, -1)),
past_key_values=[
(
torch.randn(
batch_size,
num_key_value_heads,
                    past_seq_len,
head_dim,
),
torch.randn(
batch_size,
num_key_value_heads,
                    past_seq_len,
head_dim,
),
)
for _ in range(num_hidden_layers)
],
)
example_inputs["past_key_values"] = make_dynamic_cache(
config, example_inputs["past_key_values"]
)
return example_inputs, dynamic_shapes, input_names, output_names
def make_dynamic_cache(
config: transformers.PretrainedConfig,
key_value_pairs: list[tuple[torch.Tensor, torch.Tensor]],
) -> transformers.cache_utils.DynamicCache:
cache = transformers.cache_utils.DynamicCache(config=config)
for i, (key, value) in enumerate(key_value_pairs):
cache.update(key, value, i)
return cache
# class TextGenerationModelWrapper(torch.nn.Module):
# def __init__(self, model: torch.nn.Module):
# super().__init__()
# self.model = model
# def forward(
# self,
# input_ids,
# attention_mask,
# position_ids,
# past_key_values,
# ):
# hf_output = self.model(
# input_ids=input_ids,
# attention_mask=attention_mask,
# position_ids=position_ids,
# past_key_values=past_key_values,
# )
# return hf_output.last_hidden_state, hf_output.hidden_states
model, config = get_hf_model(MODEL_ID)
example_kwargs, dynamic_shapes, input_names, output_names = (
create_text_gen_example_inputs(config)
)
# Export
print("Exporting...")
with torch._dynamo.config.patch(fake_tensor_cache_enabled=False), torch.no_grad():
onnx_program = torch.onnx.export(
transformers.integrations.executorch.TorchExportableModuleWithStaticCache(
model, batch_size=2, max_cache_len=8
),
args=(),
kwargs={
"input_ids": example_kwargs["input_ids"],
"cache_position": example_kwargs["position_ids"],
},
dynamic_shapes=dynamic_shapes,
report=True,
)
# # ONNX Export
# transformers.integrations.executorch.register_dynamic_cache_export_support()
# with torch._dynamo.config.patch(fake_tensor_cache_enabled=False), torch.no_grad():
# onnx_program = torch.onnx.export(
# model,
# (),
# kwargs=example_kwargs,
# input_names=input_names,
# output_names=output_names,
# dynamic_shapes=dynamic_shapes,
# opset_version=23,
# dynamo=True,
# report=True,
# )
# onnx_program.save("gemma3_4b_it.onnx", external_data=True)
# onnx_output = onnx_program(**example_kwargs)
# print(onnx_output)
# print("------------")
# output = model(**example_kwargs)
# print(output)
Author
import torch
from transformers import AutoConfig, AutoModel
import transformers
# MODEL_ID = "google/gemma-2b"
# MODEL_ID = "google/gemma-3-270m-it"
# MODEL_ID = "google/gemma-3-270m"
MODEL_ID = "google/gemma-3-4b-it"
def get_hf_model(model_id: str):
config = AutoConfig.from_pretrained(model_id, attn_implementation="sdpa")
# with torch.device('meta'):
# model = AutoModel.from_config(config)
# # This line is important. Some models may produce different
# # outputs even with the same inputs in training mode.
# model.eval()
model = AutoModel.from_pretrained(model_id, config=config)
model = TextGenerationModelWrapper(model)
return model, config
def create_text_gen_example_inputs(
config, batch_size: int = 2, seq_len: int = 3, past_seq_len: int = 2
):
"""Create example inputs and dynamic axes for ONNX export."""
config = config.text_config
num_hidden_layers = config.num_hidden_layers
# batch = "batch"
# sequence_len = "sequence_len"
# past_sequence_len = "past_sequence_len"
batch = torch.export.Dim("batch")
sequence_len = torch.export.Dim("sequence_len")
past_sequence_len = torch.export.Dim("past_sequence_len")
dynamic_shapes = {
"input_ids": {0: batch, 1: sequence_len},
"attention_mask": {
0: batch,
1: "past_sequence_len+sequence_len",
},
"position_ids": {
0: batch,
1: sequence_len,
},
"past_key_values": [
[{0: batch, 2: "past_sequence_len"} for _ in range(num_hidden_layers)],
[{0: batch, 2: "past_sequence_len"} for _ in range(num_hidden_layers)],
],
}
input_names = [
"input_ids",
"attention_mask",
"position_ids",
*[
name
for i in range(num_hidden_layers)
for name in (f"past_key_values.{i}.key", f"past_key_values.{i}.value")
],
]
output_names = [
"logits",
*[
name
for i in range(num_hidden_layers)
for name in (f"present.{i}.key", f"present.{i}.value")
],
]
num_key_value_heads = config.num_key_value_heads
head_dim = config.head_dim
example_inputs = dict(
input_ids=torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int64),
attention_mask=torch.ones(
(batch_size, past_seq_len + seq_len),
dtype=torch.int64,
),
position_ids=torch.arange(
past_seq_len,
past_seq_len + seq_len,
dtype=torch.int64,
).expand((batch_size, -1)),
past_key_values=make_dynamic_cache(
[
(
torch.randn(
batch_size,
num_key_value_heads,
                        past_seq_len,
head_dim,
),
torch.randn(
batch_size,
num_key_value_heads,
                        past_seq_len,
head_dim,
),
)
for _ in range(num_hidden_layers)
]
),
)
return example_inputs, dynamic_shapes, input_names, output_names
def make_dynamic_cache(
past_key_values: list[tuple[torch.Tensor, torch.Tensor]],
) -> transformers.cache_utils.DynamicCache:
cache = transformers.cache_utils.DynamicCache()
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx]
cache.update(key_states, value_states, layer_idx)
return cache
class TextGenerationModelWrapper(torch.nn.Module):
def __init__(self, model: torch.nn.Module):
super().__init__()
self.model = model
def forward(
self,
input_ids,
attention_mask,
position_ids,
past_key_values,
):
hf_output = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
)
return hf_output.last_hidden_state, hf_output.hidden_states
model, config = get_hf_model(MODEL_ID)
example_kwargs, dynamic_shapes, input_names, output_names = (
create_text_gen_example_inputs(config)
)
transformers.integrations.executorch.register_dynamic_cache_export_support()
# ONNX Export
# Disable the fake tensor cache to avoid issues with vmap
with torch._dynamo.config.patch(fake_tensor_cache_enabled=False):
onnx_program = torch.onnx.export(
model,
(),
kwargs=example_kwargs,
# input_names=input_names,
# output_names=output_names,
dynamic_shapes=dynamic_shapes,
opset_version=23,
dynamo=True,
report=True,
)
onnx_program.save("gemma3_4b_it.onnx", external_data=True)
onnx_output = onnx_program(**example_kwargs)
print(onnx_output)
print("------------")
output = model(**example_kwargs)
print(output)
Author
transformers==4.55.0
import torch
from transformers import AutoConfig, AutoModel
import transformers
# MODEL_ID = "google/gemma-2b"
# MODEL_ID = "google/gemma-3-270m-it"
# MODEL_ID = "google/gemma-3-270m"
MODEL_ID = "google/gemma-3-4b-it"
def get_hf_model(model_id: str):
config = AutoConfig.from_pretrained(model_id, attn_implementation="sdpa")
config.use_cache = True
model = AutoModel.from_pretrained(model_id, config=config)
model = TextGenerationModelWrapper(model)
return model, config
def create_text_gen_example_inputs(
config, batch_size: int = 2, seq_len: int = 3, past_seq_len: int = 2
):
"""Create example inputs and dynamic axes for ONNX export."""
config = config.text_config
num_hidden_layers = config.num_hidden_layers
# batch = "batch"
# sequence_len = "sequence_len"
# past_sequence_len = "past_sequence_len"
batch = torch.export.Dim("batch")
sequence_len = torch.export.Dim("sequence_len")
# past_sequence_len = torch.export.Dim("past_sequence_len")
dynamic_shapes = {
"input_ids": {0: batch, 1: sequence_len},
"attention_mask": {
0: batch,
1: "past_sequence_len+sequence_len",
},
"position_ids": {
0: batch,
1: sequence_len,
},
"past_key_values": [
[{0: batch, 2: "past_sequence_len"} for _ in range(num_hidden_layers)],
[{0: batch, 2: "past_sequence_len"} for _ in range(num_hidden_layers)],
],
}
input_names = [
"input_ids",
"attention_mask",
"position_ids",
*[
f"past_key_values.{i}.key" for i in range(num_hidden_layers)
],
*[
f"past_key_values.{i}.value" for i in range(num_hidden_layers)
],
]
output_names = [
"logits",
*[
f"present_key_values.{i}.key" for i in range(num_hidden_layers)
],
*[
f"present_key_values.{i}.value" for i in range(num_hidden_layers)
],
]
num_key_value_heads = config.num_key_value_heads
head_dim = config.head_dim
example_inputs = dict(
input_ids=torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int64),
attention_mask=torch.ones(
(batch_size, past_seq_len + seq_len),
dtype=torch.int64,
),
position_ids=torch.arange(
past_seq_len,
past_seq_len + seq_len,
dtype=torch.int64,
).expand((batch_size, -1)),
past_key_values=make_dynamic_cache(
[
(
torch.randn(
batch_size,
num_key_value_heads,
                        past_seq_len,
head_dim,
),
torch.randn(
batch_size,
num_key_value_heads,
                        past_seq_len,
head_dim,
),
)
for _ in range(num_hidden_layers)
]
),
)
return example_inputs, dynamic_shapes, input_names, output_names
def make_dynamic_cache(
past_key_values: list[tuple[torch.Tensor, torch.Tensor]],
) -> transformers.cache_utils.DynamicCache:
cache = transformers.cache_utils.DynamicCache()
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx]
cache.update(key_states, value_states, layer_idx)
return cache
class TextGenerationModelWrapper(torch.nn.Module):
def __init__(self, model: torch.nn.Module):
super().__init__()
self.model = model
def forward(
self,
input_ids,
attention_mask,
position_ids,
past_key_values,
):
hf_output = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
)
return hf_output.last_hidden_state, hf_output.past_key_values
model, config = get_hf_model(MODEL_ID)
example_kwargs, dynamic_shapes, input_names, output_names = (
create_text_gen_example_inputs(config)
)
# transformers.integrations.executorch.register_dynamic_cache_export_support()
# ONNX Export
# Disable the fake tensor cache to avoid issues with vmap
with torch._dynamo.config.patch(fake_tensor_cache_enabled=False):
onnx_program = torch.onnx.export(
model,
(),
kwargs=example_kwargs,
input_names=input_names,
output_names=output_names,
dynamic_shapes=dynamic_shapes,
opset_version=23,
dynamo=True,
report=True,
)
print("export successful")
onnx_program.save("gemma3_4b_it.onnx", external_data=True)
print("model saved")
onnx_output = onnx_program(**example_kwargs)
print(onnx_output)
# print("------------")
# output = model(**example_kwargs)
# print(output)
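To close the loop on the comparison that is commented out above, the ONNX Runtime result can be checked numerically against the eager wrapper instead of eyeballing prints. A minimal sketch, assuming the first ONNX output corresponds to the wrapper's last_hidden_state and that fp32 tolerances in this range are acceptable:

# Numerical check (sketch): compare the first ONNX output with eager PyTorch.
# Assumes onnx_output[0] is last_hidden_state; tolerances are illustrative for fp32.
eager_hidden, _eager_cache = model(**example_kwargs)
torch.testing.assert_close(onnx_output[0], eager_hidden, rtol=1e-3, atol=1e-3)
print("last_hidden_state matches eager PyTorch within tolerance")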