Created
November 5, 2025 06:03
-
-
Save justinchuby/3328106783e28fbccb3e6bd1673691a8 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Post-process the Perch v2 ONNX export (``perch_v2_opt3`` -> ``perch_v2_opt4``).

Pipeline, in order:

1. Bypass the Reshape feeding each MatMul and the Reshapes consuming it.
   A Reshape whose first consumer is ReduceMax is kept, but it gets a fixed
   target shape. Every Expand node is bypassed.
2. Remove nodes that became unused.
3. Clear intermediate shapes. Give the first graph input and all graph
   outputs a symbolic leading ``"batch"`` dimension.
4. Run the onnxscript optimizer.
5. Strip metadata and doc strings, rename the outputs to match the TFLite
   model, and swap the first two outputs.
"""

from __future__ import annotations

import onnx_ir as ir
import onnx_ir.passes.common
import onnxscript

INPUT_PATH = "perch_v2_opt3.onnx"
OUTPUT_PATH = "perch_v2_opt4.onnx"

# Target shape for the Reshape that feeds ReduceMax (that Reshape is kept).
_REDUCE_MAX_RESHAPE_SHAPE = [-1, 16, 4, 14795, 4]
# Name of the initializer holding _REDUCE_MAX_RESHAPE_SHAPE.
_REDUCE_MAX_SHAPE_NAME = "reshape_shape"
# Value passed to both input_size_limit and output_size_limit of the optimizer.
_OPTIMIZER_SIZE_LIMIT = 1024 * 1024 * 1024
# New names, assigned by output position before the first two are swapped.
_OUTPUT_NAMES = ("spatial_embedding", "embedding", "spectrogram", "label")


def _bypass_matmul_reshapes(
    model: ir.Model, matmul: ir.Node, reduce_max_shape: ir.Value | None
) -> ir.Value | None:
    """Remove the Reshape nodes around ``matmul``.

    The Reshape producing the MatMul's first input is skipped. Each Reshape
    consuming the MatMul's output is skipped too, unless its first consumer
    is ReduceMax. In that case the Reshape stays and its shape input (index 1)
    is pointed at a constant ``_REDUCE_MAX_RESHAPE_SHAPE`` initializer.

    Args:
        model: Model that owns ``matmul``. The shape initializer is added
            to its graph.
        matmul: A node whose ``op_type`` is ``"MatMul"``.
        reduce_max_shape: The shape initializer created by an earlier call,
            or ``None`` if it has not been created yet.

    Returns:
        The shape initializer, or ``None`` if it is still not needed. Pass
        this to the next call so only one initializer with this name is
        ever added.
    """
    producer = matmul.inputs[0].producer()
    # producer() is None when the input is a graph input or an initializer.
    if producer is not None and producer.op_type == "Reshape":
        matmul.replace_input_with(0, producer.inputs[0])

    matmul_out = matmul.outputs[0]
    # Take a snapshot. The loop below redirects consumers onto matmul_out,
    # which changes its use set while we would otherwise be iterating it.
    for usage in list(matmul_out.uses()):
        reshape = usage.node
        if reshape.op_type != "Reshape":
            continue
        reshape_out = reshape.outputs[0]
        consumers = list(reshape_out.uses())
        if consumers and consumers[0].node.op_type == "ReduceMax":
            # Keep the last Reshape, but give it a fixed target shape.
            if reduce_max_shape is None:
                reduce_max_shape = ir.val(
                    _REDUCE_MAX_SHAPE_NAME,
                    const_value=ir.tensor(_REDUCE_MAX_RESHAPE_SHAPE),
                )
                model.graph.initializers.add(reduce_max_shape)
            reshape.replace_input_with(1, reduce_max_shape)
            continue
        reshape_out.replace_all_uses_with(matmul_out)
    return reduce_max_shape


def _bypass_expand(expand: ir.Node) -> None:
    """Route consumers of an Expand node directly to its data input.

    NOTE(review): this assumes the consumers do not need the explicit
    broadcast. That holds for this model only; confirm before reusing.
    """
    expand.outputs[0].replace_all_uses_with(expand.inputs[0])


def _rewrite_nodes(model: ir.Model) -> None:
    """Apply the MatMul-Reshape and Expand bypasses to every top-level node."""
    reduce_max_shape: ir.Value | None = None
    for node in model.graph:
        if node.op_type == "MatMul":
            print(node)
            reduce_max_shape = _bypass_matmul_reshapes(model, node, reduce_max_shape)
        elif node.op_type == "Expand":
            print(node)
            _bypass_expand(node)


def _reset_shapes(model: ir.Model) -> None:
    """Drop intermediate shapes and make the leading dim symbolic ``"batch"``.

    Only non-graph-output node outputs are cleared. The first graph input and
    every graph output keep their trailing dims, with dim 0 replaced.
    """
    for node in model.graph:
        for value in node.outputs:
            if not value.is_graph_output():
                value.shape = None
    first_input = model.graph.inputs[0]
    first_input.shape = ir.Shape(["batch", *first_input.shape[1:]])
    for value in model.graph.outputs:
        value.shape = ir.Shape(["batch", *value.shape[1:]])


def _finalize_outputs(model: ir.Model) -> None:
    """Rename outputs by position, then swap outputs 0 and 1.

    Raises:
        IndexError: If the graph has fewer outputs than ``_OUTPUT_NAMES``.
    """
    outputs = model.graph.outputs
    for index, name in enumerate(_OUTPUT_NAMES):
        outputs[index].name = name
    first, second = outputs[0], outputs[1]
    outputs[1] = first
    outputs[0] = second


def main(input_path: str = INPUT_PATH, output_path: str = OUTPUT_PATH) -> None:
    """Load ``input_path``, run the rewrite pipeline, and save to ``output_path``."""
    model = ir.load(input_path)
    _rewrite_nodes(model)
    # Clean up any nodes left without consumers by the bypasses above.
    onnx_ir.passes.common.RemoveUnusedNodesPass()(model)
    _reset_shapes(model)
    onnxscript.optimizer.optimize(
        model,
        input_size_limit=_OPTIMIZER_SIZE_LIMIT,
        output_size_limit=_OPTIMIZER_SIZE_LIMIT,
    )
    onnx_ir.passes.common.ClearMetadataAndDocStringPass()(model)
    _finalize_outputs(model)
    ir.save(model, output_path)


if __name__ == "__main__":
    main()
Author
justinchuby
commented
Nov 5, 2025
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment