Qwen3EmbeddingONNXExporter
# transformers==4.52.4
# pytorch nightly
from __future__ import annotations

import ast
import logging
import typing
from pathlib import Path

import onnx_ir as ir
import onnxscript.rewriter.ort_fusions
import torch
import torch.onnx.testing
from onnx_ir.passes import PassResult
from onnx_ir.passes.common import ClearMetadataAndDocStringPass
from transformers import AutoModel, AutoTokenizer

logger = logging.getLogger(__name__)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
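

# Illustrative shape check for repeat_kv (a sketch, not part of the export flow):
#     kv = torch.zeros(2, 4, 16, 64)    # (batch, num_key_value_heads, seqlen, head_dim)
#     repeat_kv(kv, 2).shape            # -> torch.Size([2, 8, 16, 64]): GQA heads expanded to MHA layout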


def sdpa_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    dropout: float = 0.0,
    scaling: float | None = None,
    is_causal: bool | None = None,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    if hasattr(module, "num_key_value_groups"):
        key = repeat_kv(key, module.num_key_value_groups)
        value = repeat_kv(value, module.num_key_value_groups)

    causal_mask = attention_mask
    query = query.contiguous()
    key = key.contiguous()
    value = value.contiguous()

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=causal_mask,
        dropout_p=dropout,
        scale=scaling,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, None


# Patch SDPA attention for transformers
import transformers.integrations.sdpa_attention

transformers.integrations.sdpa_attention.sdpa_attention_forward = sdpa_attention_forward
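# The patch above swaps in a simplified SDPA forward (explicit repeat_kv, no
# is_causal/enable_gqa branching), which presumably keeps the traced graph free of
# data-dependent control flow so the later ORT attention fusions can match it.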


def _get_scoped_prefix(name_scopes: list[str]) -> str:
    # Remove common prefixes between consecutive scopes
    processed_scopes = []
    for i, scope in enumerate(name_scopes):
        if i == 0:
            processed_scopes.append(scope)
        else:
            prev_scope = name_scopes[i - 1]
            processed_scopes.append(scope.removeprefix(prev_scope).lstrip("."))
    return "/".join(processed_scopes)


class AssignNamesPass(ir.passes.InPlacePass):
    def call(self, model: ir.Model) -> PassResult:
        modified = False
        for node in model.graph.all_nodes():
            if "pkg.torch.onnx.name_scopes" in node.metadata_props:
                name_scopes = typing.cast(
                    "list[str]",
                    ast.literal_eval(node.metadata_props["pkg.torch.onnx.name_scopes"]),
                )
                name_scopes.pop()  # Remove self name
                prefix = _get_scoped_prefix(name_scopes)
                # Rename node
                if prefix:
                    node.name = f"{prefix}/{node.name}"
                    modified = True
                # Rename outputs
                for output in node.outputs:
                    if (
                        not output.is_graph_output()
                        and output.name is not None
                        and output.name != ""
                    ):
                        if prefix:
                            scoped_name = f"{prefix}/{output.name}"
                            logger.debug("Renaming %r to %r", output.name, scoped_name)
                            output.name = scoped_name
                            modified = True
        return PassResult(model, modified)
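

# Standalone usage sketch for AssignNamesPass (hypothetical paths; assumes the model
# still carries the exporter's "pkg.torch.onnx.name_scopes" metadata):
#     model = ir.load("exported_model.onnx")
#     result = AssignNamesPass()(model)
#     if result.modified:
#         ir.save(result.model, "exported_model_named.onnx")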


class Qwen3EmbeddingONNXExporter:
    def __init__(self, model_id="Qwen/Qwen3-Embedding-0.6B"):
        self.model_id = model_id
        self.model = None
        self.tokenizer = None

    def load_model(self):
        """Load the Qwen3 model and tokenizer"""
        print(f"Loading {self.model_id}...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(self.model_id, trust_remote_code=True)
        self.model.eval()
        print("Model loaded successfully!")

    def create_dummy_inputs(self, batch_size=2, seq_length=128):
        """Create dummy inputs for ONNX export"""
        dummy_text = ["This is a sample text for ONNX export"] * batch_size
        inputs = self.tokenizer(
            dummy_text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=seq_length,
        )
        return inputs

    def export_to_onnx(self, output_dir="./qwen3-onnx"):
        """Export model to ONNX format"""
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Save tokenizer and config
        print("Saving tokenizer and config...")
        self.tokenizer.save_pretrained(output_dir)

        # Create dummy inputs
        # NOTE(justinchuby): Batch size must be greater than 1 to be captured as dynamic
        dummy_inputs = self.create_dummy_inputs()
        input_ids = dummy_inputs["input_ids"]
        attention_mask = dummy_inputs["attention_mask"]

        # Export to ONNX
        print("Exporting to ONNX...")

        # Wrap the model WITHOUT pooling - TEI will handle pooling
        class ModelWrapper(torch.nn.Module):
            def __init__(self, model):
                super().__init__()
                self.model = model

            def forward(self, input_ids, attention_mask):
                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                if hasattr(outputs, "last_hidden_state"):
                    return outputs.last_hidden_state
                else:
                    return outputs[0]

        wrapped_model = ModelWrapper(self.model)
        wrapped_model.eval()

        onnx_program = torch.onnx.export(
            wrapped_model,
            (input_ids, attention_mask),
            input_names=["input_ids", "attention_mask"],
            output_names=["last_hidden_state"],
            dynamic_shapes={
                "input_ids": {0: "batch", 1: "seq"},
                "attention_mask": {0: "batch", 1: "seq"},
            },
            opset_version=21,
        )
        AssignNamesPass()(onnx_program.model)
        onnx_program.save(output_path / "model_pre_fusion.onnx")
        torch.onnx.testing.assert_onnx_program(onnx_program, atol=1e-4, rtol=1e-4)

        # Optimize for ORT
        model, fusion = onnxscript.rewriter.ort_fusions.optimize_for_ort(onnx_program.model)
        print(fusion)

        # For production, remove metadata:
        result = ClearMetadataAndDocStringPass()(model)
        onnx_program.model = result.model
        onnx_program.save(output_path / "model.onnx")
        torch.onnx.testing.assert_onnx_program(onnx_program, atol=1e-3, rtol=1e-3)
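

# Optional sanity check (an illustrative sketch, not wired into main()): run the
# exported model with onnxruntime and compare it against the PyTorch hidden states.
# Assumes onnxruntime is installed; the function name and tolerances are illustrative.
def check_with_onnxruntime(exporter: Qwen3EmbeddingONNXExporter, onnx_path: str) -> None:
    import numpy as np
    import onnxruntime as ort

    inputs = exporter.create_dummy_inputs(batch_size=2, seq_length=32)
    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    (ort_hidden,) = session.run(
        ["last_hidden_state"],
        {
            "input_ids": inputs["input_ids"].numpy(),
            "attention_mask": inputs["attention_mask"].numpy(),
        },
    )
    with torch.no_grad():
        torch_hidden = exporter.model(
            input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
        ).last_hidden_state
    # Compare ORT output with eager PyTorch output within a loose tolerance
    np.testing.assert_allclose(ort_hidden, torch_hidden.numpy(), atol=1e-3, rtol=1e-3)
    print("onnxruntime output matches PyTorch within tolerance")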


def main():
    exporter = Qwen3EmbeddingONNXExporter()
    exporter.load_model()
    exporter.export_to_onnx(output_dir="./qwen3-onnx-1028")


if __name__ == "__main__":
    main()
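

# Running this file directly (with the pinned transformers release and a recent PyTorch
# nightly installed) writes model_pre_fusion.onnx, model.onnx, and the tokenizer files
# to ./qwen3-onnx-1028, as configured in main() above.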