Created
November 5, 2025 06:03
-
-
Save justinchuby/3328106783e28fbccb3e6bd1673691a8 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import onnx_ir as ir | |
| import onnx_ir.passes.common | |
| import onnxscript | |
m = ir.load("perch_v2_opt3.onnx")
# Strip the Reshape -> MatMul -> Reshape wrapping around MatMul nodes and drop
# Expand nodes. ONNX MatMul follows numpy.matmul broadcasting, so the Reshapes
# (presumably flatten-to-2D / restore-rank pairs) are redundant.
reshape_shape = None  # created on first use so "reshape_shape" is added only once
for node in m.graph:
    if node.op_type == "MatMul":
        print(node)
        producer = node.inputs[0].producer()
        # Graph inputs and initializers have no producer node.
        if producer is not None and producer.op_type == "Reshape":
            # Skip the reshape
            node.replace_input_with(0, producer.inputs[0])
            # Snapshot the uses: replace_all_uses_with below adds new uses of
            # node.outputs[0] while this loop is running.
            for usage in list(node.outputs[0].uses()):
                if usage.node.op_type != "Reshape":
                    continue
                reshape_node = usage.node
                reshape_usages = list(reshape_node.outputs[0].uses())
                # Keep the last Reshape (the one feeding ReduceMax) but pin its
                # target shape; the constant is specific to this model.
                if reshape_usages and reshape_usages[0].node.op_type == "ReduceMax":
                    if reshape_shape is None:
                        reshape_shape = ir.val(
                            "reshape_shape",
                            const_value=ir.tensor([-1, 16, 4, 14795, 4]),
                        )
                        m.graph.initializers.add(reshape_shape)
                    reshape_node.replace_input_with(1, reshape_shape)
                    continue
                # Route the Reshape's consumers straight to the MatMul output.
                reshape_node.outputs[0].replace_all_uses_with(node.outputs[0])
    # Remove Expand
    elif node.op_type == "Expand":
        print(node)
        # NOTE(review): assumes every consumer of the Expand broadcasts the
        # un-expanded input to the same result — TODO confirm for this model.
        node.outputs[0].replace_all_uses_with(node.inputs[0])
# Clean up any unused nodes
onnx_ir.passes.common.RemoveUnusedNodesPass()(m)
# Clear all intermediate shapes so the optimizer re-infers them; graph
# outputs keep their declared shapes.
for node in m.graph:
    for output in node.outputs:
        if not output.is_graph_output():
            output.shape = None
# Make the leading (batch) dimension symbolic on the model input and outputs.
# Values with an unknown shape or rank 0 have no leading dimension to rename,
# so they are left untouched instead of crashing or gaining a dimension.
model_input = m.graph.inputs[0]
if model_input.shape is not None and len(model_input.shape) > 0:
    model_input.shape = ir.Shape(["batch", *model_input.shape[1:]])
for output in m.graph.outputs:
    if output.shape is not None and len(output.shape) > 0:
        output.shape = ir.Shape(["batch", *output.shape[1:]])
onnxscript.optimizer.optimize(
    m, input_size_limit=1024 * 1024 * 1024, output_size_limit=1024 * 1024 * 1024
)
onnx_ir.passes.common.ClearMetadataAndDocStringPass()(m)
# Match the TFLite model's output names, then swap the first two outputs so
# that "embedding" comes first.
graph_outputs = m.graph.outputs
tflite_names = ("spatial_embedding", "embedding", "spectrogram", "label")
for position, tflite_name in enumerate(tflite_names):
    graph_outputs[position].name = tflite_name
spatial_value, embedding_value = graph_outputs[0], graph_outputs[1]
graph_outputs[1] = spatial_value
graph_outputs[0] = embedding_value
ir.save(m, "perch_v2_opt4.onnx")
Author
justinchuby
commented
Nov 5, 2025
Author
import onnx_ir as ir
import onnx_ir.passes.common
import onnxscript
import numpy as np
m = ir.load("perch_v2_opt3.onnx")
# Strip the Reshape -> MatMul -> Reshape wrapping around MatMul nodes and drop
# Expand nodes. ONNX MatMul follows numpy.matmul broadcasting, so the Reshapes
# (presumably flatten-to-2D / restore-rank pairs) are redundant.
reshape_shape = None  # created on first use so "reshape_shape" is added only once
for node in m.graph:
    if node.op_type == "MatMul":
        print("Simplify MatMul + Reshape:", node.name)
        producer = node.inputs[0].producer()
        # Graph inputs and initializers have no producer node.
        if producer is not None and producer.op_type == "Reshape":
            # Skip the reshape
            node.replace_input_with(0, producer.inputs[0])
            # Snapshot the uses: replace_all_uses_with below adds new uses of
            # node.outputs[0] while this loop is running.
            for usage in list(node.outputs[0].uses()):
                if usage.node.op_type != "Reshape":
                    continue
                reshape_node = usage.node
                reshape_usages = list(reshape_node.outputs[0].uses())
                # Keep the last Reshape (the one feeding ReduceMax) but pin its
                # target shape; the constant is specific to this model.
                if reshape_usages and reshape_usages[0].node.op_type == "ReduceMax":
                    if reshape_shape is None:
                        reshape_shape = ir.val(
                            "reshape_shape",
                            const_value=ir.tensor([-1, 16, 4, 14795, 4]),
                        )
                        m.graph.initializers.add(reshape_shape)
                    reshape_node.replace_input_with(1, reshape_shape)
                    continue
                # Route the Reshape's consumers straight to the MatMul output.
                reshape_node.outputs[0].replace_all_uses_with(node.outputs[0])
    # Remove Expand
    elif node.op_type == "Expand":
        print("Remove Expand:", node.name)
        # NOTE(review): assumes every consumer of the Expand broadcasts the
        # un-expanded input to the same result — TODO confirm for this model.
        node.outputs[0].replace_all_uses_with(node.inputs[0])
# Drop the nodes orphaned by the rewrites above, then run the onnxscript
# optimizer for constant folding. Both size limits are 1024**3 (identical to
# 1024 * 1024 * 1024), presumably so that large tensors still get folded.
onnx_ir.passes.common.RemoveUnusedNodesPass()(m)
fold_limit = 1024**3
onnxscript.optimizer.optimize(
    m,
    input_size_limit=fold_limit,
    output_size_limit=fold_limit,
)
# INT64 constant [1], used below as the new Unsqueeze axes.
one_1d = ir.val("1d_one", const_value=ir.tensor([1], dtype=ir.DataType.INT64))
m.graph.initializers.add(one_1d)
# Simplify Unsqueeze + Reshape: re-target each Unsqueeze that feeds a Reshape
# to axes=[1] and route the Reshape's consumers to the Unsqueeze output.
# NOTE(review): assumes opset >= 13 (axes is input 1) and that the Reshape
# target equals the Unsqueeze input expanded at axis 1 — TODO confirm.
for node in m.graph:
    if node.op_type != "Reshape":
        continue
    unsqueeze_node = node.inputs[0].producer()
    if unsqueeze_node is None or unsqueeze_node.op_type != "Unsqueeze":
        continue
    # Changing the axes would also alter what any other consumer of the
    # Unsqueeze sees, so only rewrite when this Reshape is its sole consumer.
    if len(unsqueeze_node.outputs[0].uses()) != 1:
        continue
    print("Simplify Unsqueeze + Reshape:", node.name)
    unsqueeze_node.replace_input_with(1, one_1d)
    node.outputs[0].replace_all_uses_with(unsqueeze_node.outputs[0])
def _conv_bias_delta(sub_array, out_rank, num_channels):
    """Return the 1-D per-channel vector equivalent to broadcasting ``sub_array``.

    ``sub_array`` is the constant subtracted from a Conv output of rank
    ``out_rank`` laid out as (N, M, spatial...), with ``num_channels`` == M.
    Returns ``None`` when the constant varies along any axis other than the
    channel axis, because such a subtraction cannot be folded into a bias.
    """
    if sub_array.ndim > out_rank:
        return None
    # Right-align the constant against the Conv output, as broadcasting does.
    aligned = sub_array.reshape((1,) * (out_rank - sub_array.ndim) + sub_array.shape)
    if any(size != 1 for axis, size in enumerate(aligned.shape) if axis != 1):
        return None
    if aligned.shape[1] not in (1, num_channels):
        return None
    return np.broadcast_to(aligned.reshape(-1), (num_channels,))


# Fuse Conv + Sub into Conv: Conv(x, W, B) - s == Conv(x, W, B - s) whenever s
# is a constant that only varies along the channel axis.
for node in m.graph:
    if node.op_type != "Conv":
        continue
    conv_node = node
    print("Check Conv for fusion:", conv_node.name)
    conv_uses = list(conv_node.outputs[0].uses())
    # Fusing is only safe when the Sub is the sole consumer: any other consumer
    # would observe the shifted values. Skip instead of aborting the run.
    if len(conv_uses) != 1:
        continue
    sub_node, conv_input_index = conv_uses[0]
    # Only `conv_out - s` folds into the bias; `s - conv_out` does not.
    if sub_node.op_type != "Sub" or conv_input_index != 0:
        continue
    # NOTE(review): replace_all_uses_with is assumed not to rewire graph
    # outputs, which would leave the Sub subtracting a second time — verify.
    if sub_node.outputs[0].is_graph_output():
        continue
    sub_value = sub_node.inputs[1]
    if sub_value is None or sub_value.const_value is None:
        continue
    # W is (M, C/group, k1, ...): same rank as the output, M output channels.
    weight = conv_node.inputs[1]
    if weight.shape is None or not isinstance(weight.shape[0], int):
        continue
    delta = _conv_bias_delta(
        sub_value.const_value.numpy(), len(weight.shape), weight.shape[0]
    )
    if delta is None:
        continue
    existing_bias = conv_node.inputs[2] if len(conv_node.inputs) > 2 else None
    if existing_bias is None:
        new_bias = np.negative(delta)
        bias_name = f"{sub_value.name}_neg"
    elif existing_bias.const_value is not None:
        # Fold into the existing bias instead of overwriting it.
        new_bias = existing_bias.const_value.numpy() - delta
        bias_name = f"{existing_bias.name}_minus_{sub_value.name}"
    else:
        continue
    print(" Fuse Sub into Conv:", sub_node.name)
    new_bias_val = ir.val(bias_name, const_value=ir.tensor(new_bias))
    m.graph.initializers.add(new_bias_val)
    if len(conv_node.inputs) == 2:
        # NOTE: writes the private ``_inputs`` field to append an empty bias slot.
        conv_node._inputs = conv_node._inputs + (None,)
    conv_node.replace_input_with(2, new_bias_val)
    sub_node.outputs[0].replace_all_uses_with(conv_node.outputs[0])
# Clean up any unused nodes
onnx_ir.passes.common.RemoveUnusedNodesPass()(m)
# Clear all intermediate shapes so the optimizer re-infers them; graph
# outputs keep their declared shapes.
for node in m.graph:
    for output in node.outputs:
        if not output.is_graph_output():
            output.shape = None
# Make the leading (batch) dimension symbolic on the model input and outputs.
# Values with an unknown shape or rank 0 have no leading dimension to rename,
# so they are left untouched instead of crashing or gaining a dimension.
model_input = m.graph.inputs[0]
if model_input.shape is not None and len(model_input.shape) > 0:
    model_input.shape = ir.Shape(["batch", *model_input.shape[1:]])
for output in m.graph.outputs:
    if output.shape is not None and len(output.shape) > 0:
        output.shape = ir.Shape(["batch", *output.shape[1:]])
onnxscript.optimizer.optimize(
    m, input_size_limit=1024 * 1024 * 1024, output_size_limit=1024 * 1024 * 1024
)
onnx_ir.passes.common.ClearMetadataAndDocStringPass()(m)
# Match the TFLite model's IO names, then swap the first two outputs so that
# "embedding" comes first.
m.graph.inputs[0].name = "inputs"
graph_outputs = m.graph.outputs
tflite_names = ("spatial_embedding", "embedding", "spectrogram", "label")
for position, tflite_name in enumerate(tflite_names):
    graph_outputs[position].name = tflite_name
spatial_value, embedding_value = graph_outputs[0], graph_outputs[1]
graph_outputs[1] = spatial_value
graph_outputs[0] = embedding_value
ir.save(m, "perch_v2_opt4.onnx")
Author
"""Cleanup and optimize perch_v2_slim.onnx model.
This script can be applied after completing these steps:
1. Use `tf2onnx` to convert the tflite model to onnx
2. Apply onnxslim and onnxscript.optimize.optimizer on the model
3. Manually edit the model to remove the first DFT node (no-op) and fuse
the nodes that effectively takes the magnitude of the DFT output with ReduceL2.
"""
import onnx_ir as ir
import onnx_ir.passes.common
import onnxscript
import numpy as np
m = ir.load("perch_v2_slim.onnx")
# Strip the Reshape -> MatMul -> Reshape wrapping around MatMul nodes and drop
# Expand nodes. ONNX MatMul follows numpy.matmul broadcasting, so the Reshapes
# (presumably flatten-to-2D / restore-rank pairs) are redundant.
reshape_shape = None  # created on first use so "reshape_shape" is added only once
for node in m.graph:
    if node.op_type == "MatMul":
        print("Simplify MatMul + Reshape:", node.name)
        producer = node.inputs[0].producer()
        # Graph inputs and initializers have no producer node.
        if producer is not None and producer.op_type == "Reshape":
            # Skip the reshape
            node.replace_input_with(0, producer.inputs[0])
            # Snapshot the uses: replace_all_uses_with below adds new uses of
            # node.outputs[0] while this loop is running.
            for usage in list(node.outputs[0].uses()):
                if usage.node.op_type != "Reshape":
                    continue
                reshape_node = usage.node
                reshape_usages = list(reshape_node.outputs[0].uses())
                # Keep the last Reshape (the one feeding ReduceMax) but pin its
                # target shape; the constant is specific to this model.
                if reshape_usages and reshape_usages[0].node.op_type == "ReduceMax":
                    if reshape_shape is None:
                        reshape_shape = ir.val(
                            "reshape_shape",
                            const_value=ir.tensor([-1, 16, 4, 14795, 4]),
                        )
                        m.graph.initializers.add(reshape_shape)
                    reshape_node.replace_input_with(1, reshape_shape)
                    continue
                # Route the Reshape's consumers straight to the MatMul output.
                reshape_node.outputs[0].replace_all_uses_with(node.outputs[0])
    # Remove Expand
    elif node.op_type == "Expand":
        print("Remove Expand:", node.name)
        # NOTE(review): assumes every consumer of the Expand broadcasts the
        # un-expanded input to the same result — TODO confirm for this model.
        node.outputs[0].replace_all_uses_with(node.inputs[0])
# Drop the nodes orphaned by the rewrites above, then run the onnxscript
# optimizer for constant folding. Both size limits are 1024**3 (identical to
# 1024 * 1024 * 1024), presumably so that large tensors still get folded.
onnx_ir.passes.common.RemoveUnusedNodesPass()(m)
fold_limit = 1024**3
onnxscript.optimizer.optimize(
    m,
    input_size_limit=fold_limit,
    output_size_limit=fold_limit,
)
# INT64 constant [1], used below as the new Unsqueeze axes.
one_1d = ir.val("1d_one", const_value=ir.tensor([1], dtype=ir.DataType.INT64))
m.graph.initializers.add(one_1d)
# Simplify Unsqueeze + Reshape: re-target each Unsqueeze that feeds a Reshape
# to axes=[1] and route the Reshape's consumers to the Unsqueeze output.
# NOTE(review): assumes opset >= 13 (axes is input 1) and that the Reshape
# target equals the Unsqueeze input expanded at axis 1 — TODO confirm.
for node in m.graph:
    if node.op_type != "Reshape":
        continue
    unsqueeze_node = node.inputs[0].producer()
    if unsqueeze_node is None or unsqueeze_node.op_type != "Unsqueeze":
        continue
    # Changing the axes would also alter what any other consumer of the
    # Unsqueeze sees, so only rewrite when this Reshape is its sole consumer.
    if len(unsqueeze_node.outputs[0].uses()) != 1:
        continue
    print("Simplify Unsqueeze + Reshape:", node.name)
    unsqueeze_node.replace_input_with(1, one_1d)
    node.outputs[0].replace_all_uses_with(unsqueeze_node.outputs[0])
    # Model-specific shape annotation for the rewritten Unsqueeze output.
    unsqueeze_node.outputs[0].shape = ir.Shape(["batch", 160000, 1])
# Model-specific 4-D target shape for the first Reshape once the Unsqueeze
# after it is folded in.
first_reshape_shape = ir.val(
    "first_reshape_shape", const_value=ir.tensor([-1, 1, 160000, 1])
)
m.graph.initializers.add(first_reshape_shape)
# Simplify first Reshape + Unsqueeze: give the Reshape the final target shape
# and route the Unsqueeze's consumers to the Reshape output.
for node in m.graph:
    if node.op_type != "Unsqueeze":
        continue
    reshape_node = node.inputs[0].producer()
    if reshape_node is None or reshape_node.op_type != "Reshape":
        continue
    # Changing the target shape would also change what other consumers of the
    # Reshape see, so only rewrite when this Unsqueeze is its sole consumer.
    if len(reshape_node.outputs[0].uses()) != 1:
        continue
    print("Simplify Reshape + Unsqueeze:", node.name)
    reshape_node.replace_input_with(1, first_reshape_shape)
    node.outputs[0].replace_all_uses_with(reshape_node.outputs[0])
    reshape_node.outputs[0].shape = ir.Shape(["batch", 1, 160000, 1])
    # Only the first matching pair, in graph order, is rewritten.
    break
def _conv_bias_delta(sub_array, out_rank, num_channels):
    """Return the 1-D per-channel vector equivalent to broadcasting ``sub_array``.

    ``sub_array`` is the constant subtracted from a Conv output of rank
    ``out_rank`` laid out as (N, M, spatial...), with ``num_channels`` == M.
    Returns ``None`` when the constant varies along any axis other than the
    channel axis, because such a subtraction cannot be folded into a bias.
    """
    if sub_array.ndim > out_rank:
        return None
    # Right-align the constant against the Conv output, as broadcasting does.
    aligned = sub_array.reshape((1,) * (out_rank - sub_array.ndim) + sub_array.shape)
    if any(size != 1 for axis, size in enumerate(aligned.shape) if axis != 1):
        return None
    if aligned.shape[1] not in (1, num_channels):
        return None
    return np.broadcast_to(aligned.reshape(-1), (num_channels,))


# Fuse Conv + Sub into Conv: Conv(x, W, B) - s == Conv(x, W, B - s) whenever s
# is a constant that only varies along the channel axis.
for node in m.graph:
    if node.op_type != "Conv":
        continue
    conv_node = node
    print("Check Conv for fusion:", conv_node.name)
    conv_uses = list(conv_node.outputs[0].uses())
    # Fusing is only safe when the Sub is the sole consumer: any other consumer
    # would observe the shifted values. Skip instead of aborting the run.
    if len(conv_uses) != 1:
        continue
    sub_node, conv_input_index = conv_uses[0]
    # Only `conv_out - s` folds into the bias; `s - conv_out` does not.
    if sub_node.op_type != "Sub" or conv_input_index != 0:
        continue
    # NOTE(review): replace_all_uses_with is assumed not to rewire graph
    # outputs, which would leave the Sub subtracting a second time — verify.
    if sub_node.outputs[0].is_graph_output():
        continue
    sub_value = sub_node.inputs[1]
    if sub_value is None or sub_value.const_value is None:
        continue
    # W is (M, C/group, k1, ...): same rank as the output, M output channels.
    weight = conv_node.inputs[1]
    if weight.shape is None or not isinstance(weight.shape[0], int):
        continue
    delta = _conv_bias_delta(
        sub_value.const_value.numpy(), len(weight.shape), weight.shape[0]
    )
    if delta is None:
        continue
    existing_bias = conv_node.inputs[2] if len(conv_node.inputs) > 2 else None
    if existing_bias is None:
        new_bias = np.negative(delta)
        bias_name = f"{sub_value.name}_neg"
    elif existing_bias.const_value is not None:
        # Fold into the existing bias instead of overwriting it.
        new_bias = existing_bias.const_value.numpy() - delta
        bias_name = f"{existing_bias.name}_minus_{sub_value.name}"
    else:
        continue
    print(" Fuse Sub into Conv:", sub_node.name)
    new_bias_val = ir.val(bias_name, const_value=ir.tensor(new_bias))
    m.graph.initializers.add(new_bias_val)
    if len(conv_node.inputs) == 2:
        # NOTE: writes the private ``_inputs`` field to append an empty bias slot.
        conv_node._inputs = conv_node._inputs + (None,)
    conv_node.replace_input_with(2, new_bias_val)
    sub_node.outputs[0].replace_all_uses_with(conv_node.outputs[0])
# Clean up any unused nodes
onnx_ir.passes.common.RemoveUnusedNodesPass()(m)
# Clear all intermediate shapes so the optimizer re-infers them; graph
# outputs keep their declared shapes.
for node in m.graph:
    for output in node.outputs:
        if not output.is_graph_output():
            output.shape = None
# Make the leading (batch) dimension symbolic on the model input and outputs.
# Values with an unknown shape or rank 0 have no leading dimension to rename,
# so they are left untouched instead of crashing or gaining a dimension.
model_input = m.graph.inputs[0]
if model_input.shape is not None and len(model_input.shape) > 0:
    model_input.shape = ir.Shape(["batch", *model_input.shape[1:]])
for output in m.graph.outputs:
    if output.shape is not None and len(output.shape) > 0:
        output.shape = ir.Shape(["batch", *output.shape[1:]])
onnxscript.optimizer.optimize(
    m, input_size_limit=1024 * 1024 * 1024, output_size_limit=1024 * 1024 * 1024
)
onnx_ir.passes.common.ClearMetadataAndDocStringPass()(m)
# Replace None dim with "batch": every anonymous symbolic dimension on a node
# output is renamed to "batch"; known and named dims are kept as-is.
# NOTE(review): assumes every anonymous dim here is the batch dim — confirm.
for node in m.graph:
    for value in node.outputs:
        if value.shape is None:
            continue
        value.shape = ir.Shape(
            [
                ir.SymbolicDim("batch")
                if isinstance(dim, ir.SymbolicDim) and dim.value is None
                else dim
                for dim in value.shape
            ]
        )
# Match the TFLite model's IO names, then swap the first two outputs so that
# "embedding" comes first.
m.graph.inputs[0].name = "inputs"
graph_outputs = m.graph.outputs
tflite_names = ("spatial_embedding", "embedding", "spectrogram", "label")
for position, tflite_name in enumerate(tflite_names):
    graph_outputs[position].name = tflite_name
spatial_value, embedding_value = graph_outputs[0], graph_outputs[1]
graph_outputs[1] = spatial_value
graph_outputs[0] = embedding_value
# Stamp producer metadata and pin the IR version before writing the file.
m.producer_name = "onnx-ir"
m.producer_version = None
m.ir_version = 10
ir.save(m, "perch_v2.onnx")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment