@justinchuby
Last active October 14, 2025 04:33
Export HF model to ONNX
"""Export to ONNX.
transformers_version == "4.52.0"
"""
import onnx_diagnostic.tasks.text_generation
import torch
from transformers import AutoConfig, AutoModel
import onnxscript
import onnxscript.rewriter.onnx_fusions
# MODEL_ID = "google/gemma-2b"
# MODEL_ID = "google/gemma-3-270m-it"
# MODEL_ID = "google/gemma-3-270m"
MODEL_ID = "google/gemma-3-4b-it"
def get_hf_model(model_id: str):
config = AutoConfig.from_pretrained(model_id, attn_implementation="sdpa")
# with torch.device('meta'):
# model = AutoModel.from_config(config)
# # This line is important. Some models may produce different
# # outputs even with the same inputs in training mode.
# model.eval()
model = AutoModel.from_pretrained(model_id, config=config)
return model, config
def get_model_input(model: torch.nn.Module, config):
# with torch.device('meta'):
result = onnx_diagnostic.tasks.text_generation.get_inputs(
model,
config,
dummy_max_token_id=10,
num_hidden_layers=34,
num_key_value_heads=4,
head_dim=256,
cls_cache="DynamicCache"
)
return result["inputs"], result["dynamic_shapes"]
model, config = get_hf_model(MODEL_ID)
example_kwargs, dynamic_shapes = get_model_input(model, config)
# ONNX Export
onnx_program = torch.onnx.export(
model,
(),
kwargs=example_kwargs,
dynamic_shapes=dynamic_shapes,
opset_version=23,
dynamo=True,
# profile=True,
)
onnx_program.save("gemma3_4b_it.onnx")
# onnxscript.rewriter.onnx_fusions.fuse(onnx_program.model) # This is performed as part of the optimize=True option
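
# A quick numeric sanity check (minimal sketch): the returned ONNXProgram is
# callable and executes the exported graph through ONNX Runtime, so its
# outputs can be compared against the eager model on the same example inputs.
onnx_outputs = onnx_program(**example_kwargs)
torch_outputs = model(**example_kwargs)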
justinchuby commented Aug 27, 2025

from torch._subclasses.fake_tensor import FakeTensorMode
 
with FakeTensorMode():
    model = AutoModel.from_config(config)
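
# FakeTensorMode is a dispatch mode under which tensor factory calls return
# shape/dtype-only FakeTensors, so the model structure above is built without
# allocating real weights. A tiny self-contained illustration (sketch):
import torch

with FakeTensorMode():
    w = torch.empty(4096, 4096)  # no real 64 MiB buffer is allocated here
print(type(w).__name__, w.shape)  # FakeTensor torch.Size([4096, 4096])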

@justinchuby

"""Export to ONNX.

transformers_version == "4.52.0"
"""

import onnx_diagnostic.tasks.text_generation
import torch
from transformers import AutoConfig, AutoModel
from torch._subclasses.fake_tensor import FakeTensorMode

import onnxscript
import onnxscript.optimizer
import onnxscript.rewriter.onnx_fusions

# MODEL_ID = "google/gemma-2b"
# MODEL_ID = "google/gemma-3-270m-it"
# MODEL_ID = "google/gemma-3-270m"
MODEL_ID = "google/gemma-3-4b-it"

def get_hf_model(model_id: str):
    config = AutoConfig.from_pretrained(model_id, attn_implementation="sdpa")
    # with torch.device('meta'):
    #     model = AutoModel.from_config(config)
    #     # This line is important. Some models may produce different
    #     # outputs even with the same inputs in training mode.
    #     model.eval()
    model = AutoModel.from_pretrained(model_id, config=config)

    return model, config

def get_model_input(model: torch.nn.Module, config):
    # with torch.device('meta'):
    result = onnx_diagnostic.tasks.text_generation.get_inputs(
        model,
        config,
        dummy_max_token_id=10,
        num_hidden_layers=34,
        num_key_value_heads=4,
        head_dim=256,
        cls_cache="DynamicCache"
    )
    return result["inputs"], result["dynamic_shapes"]



# with FakeTensorMode():
model, config = get_hf_model(MODEL_ID)
example_kwargs, dynamic_shapes = get_model_input(model, config)

# ONNX Export
onnx_program = torch.onnx.export(
    model,
    (),
    kwargs=example_kwargs,
    dynamic_shapes=dynamic_shapes,
    opset_version=23,
    dynamo=True,
    report=True,
    optimize=False,
    # profile=True,
)

print("Saving ONNX model...")
onnx_program.save("gemma3_4b_it_original.onnx", include_initializers=False, keep_initializers_as_inputs=True)

print("Optimizing ONNX model...")
onnx_program.optimize()

print("Saving optimized ONNX model...")
onnx_program.save("gemma3_4b_it_optimized.onnx", include_initializers=False, keep_initializers_as_inputs=True)

print("Fusing ONNX model...")
onnx_program.model.functions.clear()
onnxscript.optimizer.optimize(onnx_program.model, num_iterations=2, input_size_limit=10000)
onnxscript.rewriter.onnx_fusions.fuse(onnx_program.model)
onnxscript.optimizer.optimize(onnx_program.model, num_iterations=2, input_size_limit=10000)

print("Saving fused ONNX model...")
onnx_program.save("gemma3_4b_it_fused.onnx", include_initializers=False, keep_initializers_as_inputs=True)
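
# The files above are saved with include_initializers=False and
# keep_initializers_as_inputs=True, so they carry graph structure only and
# the weights surface as extra graph inputs. A quick inspection sketch:
import onnx

proto = onnx.load("gemma3_4b_it_fused.onnx")
print(len(proto.graph.input), "graph inputs;", len(proto.graph.initializer), "stored initializers")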

justinchuby commented Oct 13, 2025

transformers==4.56.0

import torch
from transformers import AutoConfig, AutoModel
import transformers

# MODEL_ID = "google/gemma-2b"
# MODEL_ID = "google/gemma-3-270m-it"
# MODEL_ID = "google/gemma-3-270m"
MODEL_ID = "google/gemma-3-4b-it"


def get_hf_model(model_id: str):
    config = AutoConfig.from_pretrained(model_id, attn_implementation="sdpa")
    # with torch.device('meta'):
    #     model = AutoModel.from_config(config)
    #     # This line is important. Some models may produce different
    #     # outputs even with the same inputs in training mode.
    #     model.eval()
    model = AutoModel.from_pretrained(model_id, config=config)
    model = TextGenerationModelWrapper(model)

    return model, config


def create_text_gen_example_inputs(
    config, batch_size: int = 2, seq_len: int = 3, past_seq_len: int = 2
):
    """Create example inputs and dynamic axes for ONNX export."""
    config = config.text_config
    num_hidden_layers = config.num_hidden_layers
    batch = "batch"
    sequence_len = "sequence_len"
    past_sequence_len = "past_sequence_len"

    dynamic_shapes = {
        "input_ids": {0: batch, 1: sequence_len},
        "attention_mask": {
            0: batch,
            1: "past_sequence_len+sequence_len",
        },
        "position_ids": {
            0: batch,
            1: sequence_len,
        },
        "past_key_values": [
            ({0: batch, 2: past_sequence_len}, {0: batch, 2: past_sequence_len})
            for _ in range(num_hidden_layers)
        ],
    }
    input_names = [
        "input_ids",
        "attention_mask",
        "position_ids",
        *[
            name
            for i in range(num_hidden_layers)
            for name in (f"past_key_values.{i}.key", f"past_key_values.{i}.value")
        ],
    ]
    output_names = [
        "logits",
        *[
            name
            for i in range(num_hidden_layers)
            for name in (f"present.{i}.key", f"present.{i}.value")
        ],
    ]

    num_key_value_heads = config.num_key_value_heads
    head_dim = config.head_dim

    example_inputs = dict(
        input_ids=torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int64),
        attention_mask=torch.ones(
            (batch_size, past_seq_len + seq_len),
            dtype=torch.int64,
        ),
        position_ids=torch.arange(
            past_seq_len,
            past_seq_len + seq_len,
            dtype=torch.int64,
        ).expand((batch_size, -1)),
        past_key_values=[
            (
                torch.randn(
                    batch_size,
                    num_key_value_heads,
                    seq_len,
                    head_dim,
                ),
                torch.randn(
                    batch_size,
                    num_key_value_heads,
                    seq_len,
                    head_dim,
                ),
            )
            for _ in range(num_hidden_layers)
        ],
    )

    return example_inputs, dynamic_shapes, input_names, output_names


def make_dynamic_cache(
    key_value_pairs: list[tuple[torch.Tensor, torch.Tensor]],
) -> transformers.cache_utils.DynamicCache:
    cache = transformers.cache_utils.DynamicCache()  # start empty; update() fills each layer
    for i, (key, value) in enumerate(key_value_pairs):
        cache.update(key, value, i)
    return cache
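
# DynamicCache.update(key, value, layer_idx) appends the given states for
# that layer and returns the accumulated tensors. A model-free toy of what
# make_dynamic_cache builds (sketch):
_toy_cache = transformers.cache_utils.DynamicCache()
_toy_cache.update(torch.randn(2, 4, 3, 8), torch.randn(2, 4, 3, 8), 0)
assert _toy_cache.get_seq_length() == 3  # (batch, kv_heads, seq, head_dim)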


class TextGenerationModelWrapper(torch.nn.Module):
    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(
        self,
        input_ids,
        attention_mask,
        position_ids,
        past_key_values: list[tuple[torch.Tensor, torch.Tensor]],
    ):
        hf_output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=make_dynamic_cache(past_key_values),
        )
        return hf_output.last_hidden_state, hf_output.hidden_states


model, config = get_hf_model(MODEL_ID)
example_kwargs, dynamic_shapes, input_names, output_names = (
    create_text_gen_example_inputs(config)
)

# ONNX Export
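# torch._dynamo.config.patch(...) is a context manager here: it temporarily
# sets fake_tensor_cache_enabled = False for the duration of the export and
# restores the previous value on exit.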
with torch._dynamo.config.patch(fake_tensor_cache_enabled=False):
    onnx_program = torch.onnx.export(
        model,
        (),
        kwargs=example_kwargs,
        input_names=input_names,
        output_names=output_names,
        dynamic_shapes=dynamic_shapes,
        opset_version=23,
        dynamo=True,
        report=True,
    )

onnx_program.save("gemma3_4b_it.onnx", external_data=True)

onnx_output = onnx_program.call_reference(**example_kwargs)

print(onnx_output)

print("------------")

output = model(**example_kwargs)
print(output)
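
# A tighter check than eyeballing the prints (sketch; assumes the first ONNX
# output corresponds to last_hidden_state from the eager wrapper):
torch.testing.assert_close(
    torch.as_tensor(onnx_output[0]), output[0], rtol=1e-3, atol=1e-3
)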

@justinchuby

import torch
from transformers import AutoConfig, AutoModel
import transformers

# MODEL_ID = "google/gemma-2b"
# MODEL_ID = "google/gemma-3-270m-it"
# MODEL_ID = "google/gemma-3-270m"
MODEL_ID = "google/gemma-3-4b-it"


def get_hf_model(model_id: str):
    config = AutoConfig.from_pretrained(model_id, attn_implementation="sdpa")
    # with torch.device('meta'):
    #     model = AutoModel.from_config(config)
    #     # This line is important. Some models may produce different
    #     # outputs even with the same inputs in training mode.
    #     model.eval()
    model = AutoModel.from_pretrained(model_id, config=config)
    model = TextGenerationModelWrapper(model)

    return model, config


def create_text_gen_example_inputs(
    config, batch_size: int = 2, seq_len: int = 3, past_seq_len: int = 2
):
    """Create example inputs and dynamic axes for ONNX export."""
    config = config.text_config
    num_hidden_layers = config.num_hidden_layers
    # batch = "batch"
    # sequence_len = "sequence_len"
    # past_sequence_len = "past_sequence_len"
    batch = torch.export.Dim("batch")
    sequence_len = torch.export.Dim("sequence_len")
    past_sequence_len = torch.export.Dim("past_sequence_len")

    dynamic_shapes = {
        "input_ids": {0: batch, 1: sequence_len},
        "attention_mask": {
            0: batch,
            1: "past_sequence_len+sequence_len",
        },
        "position_ids": {
            0: batch,
            1: sequence_len,
        },
        "past_key_values": [
            ({0: batch, 2: past_sequence_len}, {0: batch, 2: past_sequence_len})
            for _ in range(num_hidden_layers)
        ],
    }
    input_names = [
        "input_ids",
        "attention_mask",
        "position_ids",
        *[
            name
            for i in range(num_hidden_layers)
            for name in (f"past_key_values.{i}.key", f"past_key_values.{i}.value")
        ],
    ]
    output_names = [
        "logits",
        *[
            name
            for i in range(num_hidden_layers)
            for name in (f"present.{i}.key", f"present.{i}.value")
        ],
    ]

    num_key_value_heads = config.num_key_value_heads
    head_dim = config.head_dim

    example_inputs = dict(
        input_ids=torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int64),
        attention_mask=torch.ones(
            (batch_size, past_seq_len + seq_len),
            dtype=torch.int64,
        ),
        position_ids=torch.arange(
            past_seq_len,
            past_seq_len + seq_len,
            dtype=torch.int64,
        ).expand((batch_size, -1)),
        past_key_values=[
            (
                torch.randn(
                    batch_size,
                    num_key_value_heads,
                    seq_len,
                    head_dim,
                ),
                torch.randn(
                    batch_size,
                    num_key_value_heads,
                    seq_len,
                    head_dim,
                ),
            )
            for _ in range(num_hidden_layers)
        ],
    )

    example_inputs["past_key_values"] = make_dynamic_cache(
        example_inputs["past_key_values"]
    )

    return example_inputs, dynamic_shapes, input_names, output_names


def make_dynamic_cache(
    key_value_pairs: list[tuple[torch.Tensor, torch.Tensor]],
) -> transformers.cache_utils.DynamicCache:
    cache = transformers.cache_utils.DynamicCache()
    for i, (key, value) in enumerate(key_value_pairs):
        cache.update(key, value, i)
    return cache


class TextGenerationModelWrapper(torch.nn.Module):
    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model
        self.dynamic_cache = None

    def forward(
        self,
        input_ids,
        attention_mask,
        position_ids,
        past_key_values,
    ):
        hf_output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
        )
        return hf_output.last_hidden_state, hf_output.hidden_states


model, config = get_hf_model(MODEL_ID)
example_kwargs, dynamic_shapes, input_names, output_names = (
    create_text_gen_example_inputs(config)
)

# ONNX Export
with torch._dynamo.config.patch(fake_tensor_cache_enabled=False):
    onnx_program = torch.onnx.export(
        model,
        (),
        kwargs=example_kwargs,
        input_names=input_names,
        output_names=output_names,
        dynamic_shapes=dynamic_shapes,
        opset_version=23,
        dynamo=True,
        report=True,
    )

onnx_program.save("gemma3_4b_it.onnx", external_data=True)

onnx_output = onnx_program(**example_kwargs)

print(onnx_output)

print("------------")

output = model(**example_kwargs)
print(output)

@justinchuby

Giving up on the DynamicCache export path for now; trying the executorch static-cache wrapper instead:

import torch
from transformers import AutoConfig, AutoModel, GenerationConfig
import transformers

# MODEL_ID = "google/gemma-2b"
# MODEL_ID = "google/gemma-3-270m-it"
# MODEL_ID = "google/gemma-3-270m"
MODEL_ID = "google/gemma-3-4b-it"


def get_hf_model(model_id: str):
    config = AutoConfig.from_pretrained(model_id, attn_implementation="sdpa")
    # with torch.device('meta'):
    #     model = AutoModel.from_config(config)
    #     # This line is important. Some models may produce different
    #     # outputs even with the same inputs in training mode.
    #     model.eval()
    model = AutoModel.from_pretrained(model_id, config=config)
    generation_config = GenerationConfig.from_pretrained(
        model_id, cache_implementation="static"
    )
    model.generation_config = generation_config
    # model = TextGenerationModelWrapper(model)

    return model, config


def create_text_gen_example_inputs(
    config, batch_size: int = 2, seq_len: int = 8, past_seq_len: int = 2
):
    """Create example inputs and dynamic axes for ONNX export."""
    config = config.text_config
    num_hidden_layers = config.num_hidden_layers
    # batch = "batch"
    # sequence_len = "sequence_len"
    # past_sequence_len = "past_sequence_len"
    batch = torch.export.Dim("batch")
    sequence_len = torch.export.Dim("sequence_len")
    past_sequence_len = torch.export.Dim("past_sequence_len")

    # dynamic_shapes = {
    #     "input_ids": {0: batch, 1: sequence_len},
    #     "attention_mask": {
    #         0: batch,
    #         1: "past_sequence_len+sequence_len",
    #     },
    #     "position_ids": {
    #         0: batch,
    #         1: sequence_len,
    #     },
    #     "past_key_values": [
    #         ({0: batch, 2: past_sequence_len}, {0: batch, 2: past_sequence_len})
    #         for _ in range(num_hidden_layers)
    #     ],
    # }
    dynamic_shapes = {
        "input_ids": {0: batch, 1: sequence_len},
        "cache_position": {
            0: batch,
            1: sequence_len,
        },
    }
    input_names = [
        "input_ids",
        "attention_mask",
        "position_ids",
        *[
            name
            for i in range(num_hidden_layers)
            for name in (f"past_key_values.{i}.key", f"past_key_values.{i}.value")
        ],
    ]
    output_names = [
        "logits",
        *[
            name
            for i in range(num_hidden_layers)
            for name in (f"present.{i}.key", f"present.{i}.value")
        ],
    ]

    num_key_value_heads = config.num_key_value_heads
    head_dim = config.head_dim

    example_inputs = dict(
        input_ids=torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int64),
        attention_mask=torch.ones(
            (batch_size, past_seq_len + seq_len),
            dtype=torch.int64,
        ),
        position_ids=torch.arange(
            past_seq_len,
            past_seq_len + seq_len,
            dtype=torch.int64,
        ).expand((batch_size, -1)),
        past_key_values=[
            (
                torch.randn(
                    batch_size,
                    num_key_value_heads,
                    seq_len,
                    head_dim,
                ),
                torch.randn(
                    batch_size,
                    num_key_value_heads,
                    seq_len,
                    head_dim,
                ),
            )
            for _ in range(num_hidden_layers)
        ],
    )

    example_inputs["past_key_values"] = make_dynamic_cache(
        config, example_inputs["past_key_values"]
    )

    return example_inputs, dynamic_shapes, input_names, output_names


def make_dynamic_cache(
    config: transformers.PretrainedConfig,
    key_value_pairs: list[tuple[torch.Tensor, torch.Tensor]],
) -> transformers.cache_utils.DynamicCache:
    cache = transformers.cache_utils.DynamicCache(config=config)
    for i, (key, value) in enumerate(key_value_pairs):
        cache.update(key, value, i)
    return cache


# class TextGenerationModelWrapper(torch.nn.Module):
#     def __init__(self, model: torch.nn.Module):
#         super().__init__()
#         self.model = model

#     def forward(
#         self,
#         input_ids,
#         attention_mask,
#         position_ids,
#         past_key_values,
#     ):
#         hf_output = self.model(
#             input_ids=input_ids,
#             attention_mask=attention_mask,
#             position_ids=position_ids,
#             past_key_values=past_key_values,
#         )
#         return hf_output.last_hidden_state, hf_output.hidden_states


model, config = get_hf_model(MODEL_ID)
example_kwargs, dynamic_shapes, input_names, output_names = (
    create_text_gen_example_inputs(config)
)

# Export

print("Exporting...")
with torch._dynamo.config.patch(fake_tensor_cache_enabled=False), torch.no_grad():
    onnx_program = torch.onnx.export(
        transformers.integrations.executorch.TorchExportableModuleWithStaticCache(
            model, batch_size=2, max_cache_len=8
        ),
        args=(),
        kwargs={
            "input_ids": example_kwargs["input_ids"],
            "cache_position": example_kwargs["position_ids"],
        },
        dynamic_shapes=dynamic_shapes,
        report=True,
    )
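
# With TorchExportableModuleWithStaticCache the KV cache is pre-allocated as
# module state (a StaticCache sized by batch_size/max_cache_len), so only
# input_ids and cache_position remain as traced inputs; attention_mask and
# the per-layer past tensors drop out of the exported signature.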

# # ONNX Export
# transformers.integrations.executorch.register_dynamic_cache_export_support()
# with torch._dynamo.config.patch(fake_tensor_cache_enabled=False), torch.no_grad():
#     onnx_program = torch.onnx.export(
#         model,
#         (),
#         kwargs=example_kwargs,
#         input_names=input_names,
#         output_names=output_names,
#         dynamic_shapes=dynamic_shapes,
#         opset_version=23,
#         dynamo=True,
#         report=True,
#     )

# onnx_program.save("gemma3_4b_it.onnx", external_data=True)

# onnx_output = onnx_program(**example_kwargs)

# print(onnx_output)

# print("------------")

# output = model(**example_kwargs)
# print(output)

@justinchuby

import torch
from transformers import AutoConfig, AutoModel
import transformers

# MODEL_ID = "google/gemma-2b"
# MODEL_ID = "google/gemma-3-270m-it"
# MODEL_ID = "google/gemma-3-270m"
MODEL_ID = "google/gemma-3-4b-it"


def get_hf_model(model_id: str):
    config = AutoConfig.from_pretrained(model_id, attn_implementation="sdpa")
    # with torch.device('meta'):
    #     model = AutoModel.from_config(config)
    #     # This line is important. Some models may produce different
    #     # outputs even with the same inputs in training mode.
    #     model.eval()
    model = AutoModel.from_pretrained(model_id, config=config)
    model = TextGenerationModelWrapper(model)

    return model, config


def create_text_gen_example_inputs(
    config, batch_size: int = 2, seq_len: int = 3, past_seq_len: int = 2
):
    """Create example inputs and dynamic axes for ONNX export."""
    config = config.text_config
    num_hidden_layers = config.num_hidden_layers
    # batch = "batch"
    # sequence_len = "sequence_len"
    # past_sequence_len = "past_sequence_len"
    batch = torch.export.Dim("batch")
    sequence_len = torch.export.Dim("sequence_len")
    past_sequence_len = torch.export.Dim("past_sequence_len")

    dynamic_shapes = {
        "input_ids": {0: batch, 1: sequence_len},
        "attention_mask": {
            0: batch,
            1: "past_sequence_len+sequence_len",
        },
        "position_ids": {
            0: batch,
            1: sequence_len,
        },
        "past_key_values": [
            [{0: batch, 2: "past_sequence_len"} for _ in range(num_hidden_layers)],
            [{0: batch, 2: "past_sequence_len"} for _ in range(num_hidden_layers)],
        ],
    }
    input_names = [
        "input_ids",
        "attention_mask",
        "position_ids",
        *[
            name
            for i in range(num_hidden_layers)
            for name in (f"past_key_values.{i}.key", f"past_key_values.{i}.value")
        ],
    ]
    output_names = [
        "logits",
        *[
            name
            for i in range(num_hidden_layers)
            for name in (f"present.{i}.key", f"present.{i}.value")
        ],
    ]

    num_key_value_heads = config.num_key_value_heads
    head_dim = config.head_dim

    example_inputs = dict(
        input_ids=torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int64),
        attention_mask=torch.ones(
            (batch_size, past_seq_len + seq_len),
            dtype=torch.int64,
        ),
        position_ids=torch.arange(
            past_seq_len,
            past_seq_len + seq_len,
            dtype=torch.int64,
        ).expand((batch_size, -1)),
        past_key_values=make_dynamic_cache(
            [
                (
                    torch.randn(
                        batch_size,
                        num_key_value_heads,
                        seq_len,
                        head_dim,
                    ),
                    torch.randn(
                        batch_size,
                        num_key_value_heads,
                        seq_len,
                        head_dim,
                    ),
                )
                for _ in range(num_hidden_layers)
            ]
        ),
    )

    return example_inputs, dynamic_shapes, input_names, output_names


def make_dynamic_cache(
    past_key_values: list[tuple[torch.Tensor, torch.Tensor]],
) -> transformers.cache_utils.DynamicCache:
    cache = transformers.cache_utils.DynamicCache()
    for layer_idx in range(len(past_key_values)):
        key_states, value_states = past_key_values[layer_idx]
        cache.update(key_states, value_states, layer_idx)
    return cache


class TextGenerationModelWrapper(torch.nn.Module):
    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(
        self,
        input_ids,
        attention_mask,
        position_ids,
        past_key_values,
    ):
        hf_output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
        )
        return hf_output.last_hidden_state, hf_output.hidden_states


model, config = get_hf_model(MODEL_ID)
example_kwargs, dynamic_shapes, input_names, output_names = (
    create_text_gen_example_inputs(config)
)

transformers.integrations.executorch.register_dynamic_cache_export_support()
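
# register_dynamic_cache_export_support() registers DynamicCache with
# torch.export's pytree machinery, so the cache flattens into its key/value
# tensors during tracing. A quick check (sketch):
import torch.utils._pytree as pytree

_flat, _spec = pytree.tree_flatten(example_kwargs["past_key_values"])
print(len(_flat), "tensors flattened from the DynamicCache")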

# ONNX Export
# Disable the fake tensor cache to avoid issues with vmap during export
with torch._dynamo.config.patch(fake_tensor_cache_enabled=False):
    onnx_program = torch.onnx.export(
        model,
        (),
        kwargs=example_kwargs,
        # input_names=input_names,
        # output_names=output_names,
        dynamic_shapes=dynamic_shapes,
        opset_version=23,
        dynamo=True,
        report=True,
    )

onnx_program.save("gemma3_4b_it.onnx", external_data=True)

onnx_output = onnx_program(**example_kwargs)

print(onnx_output)

print("------------")

output = model(**example_kwargs)
print(output)

justinchuby commented Oct 14, 2025

transformers==4.55.0

import torch
from transformers import AutoConfig, AutoModel
import transformers

# MODEL_ID = "google/gemma-2b"
# MODEL_ID = "google/gemma-3-270m-it"
# MODEL_ID = "google/gemma-3-270m"
MODEL_ID = "google/gemma-3-4b-it"


def get_hf_model(model_id: str):
    config = AutoConfig.from_pretrained(model_id, attn_implementation="sdpa")
    config.use_cache = True
    model = AutoModel.from_pretrained(model_id, config=config)
    model = TextGenerationModelWrapper(model)

    return model, config


def create_text_gen_example_inputs(
    config, batch_size: int = 2, seq_len: int = 3, past_seq_len: int = 2
):
    """Create example inputs and dynamic axes for ONNX export."""
    config = config.text_config
    num_hidden_layers = config.num_hidden_layers
    # batch = "batch"
    # sequence_len = "sequence_len"
    # past_sequence_len = "past_sequence_len"
    batch = torch.export.Dim("batch")
    sequence_len = torch.export.Dim("sequence_len")
    # past_sequence_len = torch.export.Dim("past_sequence_len")

    dynamic_shapes = {
        "input_ids": {0: batch, 1: sequence_len},
        "attention_mask": {
            0: batch,
            1: "past_sequence_len+sequence_len",
        },
        "position_ids": {
            0: batch,
            1: sequence_len,
        },
        "past_key_values": [
            [{0: batch, 2: "past_sequence_len"} for _ in range(num_hidden_layers)],
            [{0: batch, 2: "past_sequence_len"} for _ in range(num_hidden_layers)],
        ],
    }
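    # The past_key_values entry is two parallel lists (all keys, then all
    # values), matching how DynamicCache flattens for torch.export in this
    # transformers version.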
    input_names = [
        "input_ids",
        "attention_mask",
        "position_ids",
        *[
            f"past_key_values.{i}.key" for i in range(num_hidden_layers)
        ],
        *[
            f"past_key_values.{i}.value" for i in range(num_hidden_layers)
        ],
    ]
    output_names = [
        "logits",
        *[
            f"present_key_values.{i}.key" for i in range(num_hidden_layers)
        ],
        *[
            f"present_key_values.{i}.value" for i in range(num_hidden_layers)
        ],
    ]

    num_key_value_heads = config.num_key_value_heads
    head_dim = config.head_dim

    example_inputs = dict(
        input_ids=torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int64),
        attention_mask=torch.ones(
            (batch_size, past_seq_len + seq_len),
            dtype=torch.int64,
        ),
        position_ids=torch.arange(
            past_seq_len,
            past_seq_len + seq_len,
            dtype=torch.int64,
        ).expand((batch_size, -1)),
        past_key_values=make_dynamic_cache(
            [
                (
                    torch.randn(
                        batch_size,
                        num_key_value_heads,
                        seq_len,
                        head_dim,
                    ),
                    torch.randn(
                        batch_size,
                        num_key_value_heads,
                        seq_len,
                        head_dim,
                    ),
                )
                for _ in range(num_hidden_layers)
            ]
        ),
    )

    return example_inputs, dynamic_shapes, input_names, output_names


def make_dynamic_cache(
    past_key_values: list[tuple[torch.Tensor, torch.Tensor]],
) -> transformers.cache_utils.DynamicCache:
    cache = transformers.cache_utils.DynamicCache()
    for layer_idx in range(len(past_key_values)):
        key_states, value_states = past_key_values[layer_idx]
        cache.update(key_states, value_states, layer_idx)
    return cache


class TextGenerationModelWrapper(torch.nn.Module):
    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(
        self,
        input_ids,
        attention_mask,
        position_ids,
        past_key_values,
    ):
        hf_output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
        )
        return hf_output.last_hidden_state, hf_output.past_key_values


model, config = get_hf_model(MODEL_ID)
example_kwargs, dynamic_shapes, input_names, output_names = (
    create_text_gen_example_inputs(config)
)

# transformers.integrations.executorch.register_dynamic_cache_export_support()

# ONNX Export
# Disable the fake tensor cache to avoid issues with vmap during export
with torch._dynamo.config.patch(fake_tensor_cache_enabled=False):
    onnx_program = torch.onnx.export(
        model,
        (),
        kwargs=example_kwargs,
        input_names=input_names,
        output_names=output_names,
        dynamic_shapes=dynamic_shapes,
        opset_version=23,
        dynamo=True,
        report=True,
    )

print("export successful")

onnx_program.save("gemma3_4b_it.onnx", external_data=True)

print("model saved")

onnx_output = onnx_program(**example_kwargs)

print(onnx_output)

# print("------------")

# output = model(**example_kwargs)
# print(output)
