# onnx2pytorch.py: rebuild a runnable PyTorch nn.Module from an ONNX model.
import onnx
import struct
import torch
import torch.nn as nn
import torchvision as tv
import warnings

# enum DataType {
#   UNDEFINED = 0;
#   // Basic types.
#   FLOAT = 1;    // float
#   UINT8 = 2;    // uint8_t
#   INT8 = 3;     // int8_t
#   UINT16 = 4;   // uint16_t
#   INT16 = 5;    // int16_t
#   INT32 = 6;    // int32_t
#   INT64 = 7;    // int64_t
#   STRING = 8;   // string
#   BOOL = 9;     // bool
#
#   // IEEE754 half-precision floating-point format (16 bits wide).
#   // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
#   FLOAT16 = 10;
#
#   DOUBLE = 11;
#   UINT32 = 12;
#   UINT64 = 13;
#   COMPLEX64 = 14;   // complex with float32 real and imaginary components
#   COMPLEX128 = 15;  // complex with float64 real and imaginary components
#
#   // Non-IEEE floating-point format based on IEEE754 single-precision
#   // floating-point number truncated to 16 bits.
#   // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
#   BFLOAT16 = 16;
#
#   // Future extensions go here.
# }

# TODO more types maybe?
# Maps ONNX TensorProto data types to (struct format char, byte size).
data_type_tab = {
    1: ['f', 4],
    2: ['B', 1],
    3: ['b', 1],
    4: ['H', 2],
    5: ['h', 2],
    6: ['i', 4],
    7: ['q', 8],
    10: ['e', 2],
    11: ['d', 8],
    12: ['I', 4],
    13: ['Q', 8]
}

def empty(x):
    return x

# TODO pytorch only accepts 2-value list for padding.
def _slim422(l4):
    assert len(l4) == 4
    p0, p1 = l4[::2]
    if l4[0] == 0:  # TODO bad code
        p0 = l4[2] // 2
        if l4[2] == 1:
            p0 = 1
    if l4[1] == 0:  # TODO bad code
        p1 = l4[3] // 2
        if l4[3] == 1:
            p1 = 1
    return p0, p1

def _check_attr(attrs, map):
    for attr in attrs:
        if attr.name not in map:
            warnings.warn("Missing {} in parser's attr_map.".format(attr.name))

def unpack_weights(initializer):
    ret = {}
    for i in initializer:
        name = i.name
        dtype = i.data_type
        shape = list(i.dims)
        if dtype not in data_type_tab:
            warnings.warn("This data type {} is not supported yet.".format(dtype))
        fmt, size = data_type_tab[dtype]
        if len(i.raw_data) == 0:
            if dtype == 1:
                data_list = i.float_data
            elif dtype == 7:
                data_list = i.int64_data
            else:
                warnings.warn("No-raw-data type {} not supported yet.".format(dtype))
        else:
            data_list = struct.unpack('<' + fmt * (len(i.raw_data) // size), i.raw_data)
        t = torch.tensor(data_list)
        if len(shape) != 0:
            t = t.view(*shape)
        ret[name] = t
    return ret

def rebuild_lrn(node, weights):
    # size, alpha = 1e-4, beta = 0.75, k = 1.
    rebuild_lrn.lrn_attr_map = {
        'size': 'size',
        'alpha': 'alpha',
        'beta': 'beta',
        'bias': 'k'
    }
    kwargs = {}
    for att in node.attribute:
        kwargs[rebuild_lrn.lrn_attr_map[att.name]] = att.f if att.name != 'size' else att.i
    return nn.LocalResponseNorm(**kwargs), node.input, node.output

def rebuild_conv(node, weights):
    rebuild_conv.conv_attr_map = {
        "pads": "padding",
        "strides": "stride",
        "kernel_shape": "kernel_size",
        "group": "groups",
        "dilations": "dilation"
    }
    assert len(node.output) == 1
    with_bias = False
    if len(node.input) == 3:
        with_bias = True
        bias_name = node.input[2]
        bias = weights[bias_name]
    weight_name = node.input[1]
    weight = weights[weight_name]
    in_channels = weight.shape[1]
    out_channels = weight.shape[0]
    kwargs = {}
    for att in node.attribute:
        kwargs[rebuild_conv.conv_attr_map[att.name]] = list(att.ints) if att.name != 'group' else att.i
    if 'padding' in kwargs:
        kwargs["padding"] = _slim422(kwargs["padding"])
    groups = 1 if 'groups' not in kwargs else kwargs['groups']
    in_channels *= groups
    conv = nn.Conv2d(in_channels, out_channels, **kwargs, bias=with_bias)
    conv.weight.data = weight
    if with_bias:
        conv.bias.data = bias
    return conv, node.input[:1], node.output

def rebuild_dropout(node, weights):
    ratio = node.attribute[0].f
    return nn.Dropout2d(p=ratio), node.input, node.output

def rebuild_batchnormalization(node, weights):
    rebuild_batchnormalization.bn_attr_map = {
        "epsilon": "eps",
        "momentum": "momentum"
    }
    assert len(node.input) == 5
    assert len(node.output) == 1
    weight = weights[node.input[1]]
    bias = weights[node.input[2]]
    running_mean = weights[node.input[3]]
    running_var = weights[node.input[4]]
    dim = weight.shape[0]
    kwargs = {}
    _check_attr(node.attribute, rebuild_batchnormalization.bn_attr_map)
    for att in node.attribute:
        if att.name in rebuild_batchnormalization.bn_attr_map:
            kwargs[rebuild_batchnormalization.bn_attr_map[att.name]] = att.f
    # Pass eps/momentum collected from the ONNX attributes instead of silently dropping them.
    bn = nn.BatchNorm2d(num_features=dim, **kwargs)
    bn.weight.data = weight
    bn.bias.data = bias
    bn.running_mean.data = running_mean
    bn.running_var.data = running_var
    return bn, node.input[:1], node.output

def rebuild_relu(node, weights):
    return nn.ReLU(), node.input, node.output

def rebuild_maxpool(node, weights):
    rebuild_maxpool.mp_attr_map = {
        "pads": "padding",
        "strides": "stride",
        "kernel_shape": "kernel_size",
    }
    kwargs = {}
    for att in node.attribute:
        kwargs[rebuild_maxpool.mp_attr_map[att.name]] = list(att.ints)
    if 'padding' in kwargs:
        kwargs["padding"] = _slim422(kwargs["padding"])
    mp = nn.MaxPool2d(**kwargs)
    return mp, node.input, node.output

def rebuild_add(node, weights):
    def add(a, b):
        return a + b
    return add, node.input, node.output

def rebuild_globalaveragepool(node, weights):
    avg_pool = nn.AdaptiveAvgPool2d((1, 1))
    return avg_pool, node.input, node.output

def rebuild_transpose(node, weights):
    perm = node.attribute[0].ints

    def transpose(x):
        x = x.permute(*perm)
        return x
    return transpose, node.input, node.output

def rebuild_flatten(node, weights):
    if len(node.attribute) == 0:
        d = 1
    else:
        d = node.attribute[0].i

    def flatten(x):
        o_shape = []
        for i in range(d):
            o_shape.append(x.shape[i])
        o_shape.append(-1)
        return x.view(*o_shape)
    return flatten, node.input, node.output

def rebuild_gemm(node, weights):
    weight = weights[node.input[1]]
    bias = weights[node.input[2]]
    in_feats = weight.shape[1]
    out_feats = weight.shape[0]
    linear = nn.Linear(in_features=in_feats, out_features=out_feats)
    linear.weight.data = weight
    linear.bias.data = bias
    return linear, node.input[:1], node.output

def rebuild_concat(node, weights):
    dim = node.attribute[0].i

    def concat(*inputs):
        # for i in inputs:
        #     print(i.shape)
        ret = torch.cat(inputs, dim)
        # print(ret.shape)
        # exit()
        return ret
    return concat, node.input, node.output

def rebuild_pad(node, weights):
    mode = node.attribute[0].s
    pads = list(node.attribute[1].ints)
    value = node.attribute[2].f
    assert mode == b'constant'  # TODO constant only
    assert sum(pads[:4]) == 0   # TODO pad2d only
    pad = nn.ConstantPad2d(pads[4:], value)
    return pad, node.input, node.output

def rebuild_constant(node, weights):
    raw_data = node.attribute[0].t.raw_data
    data_type = node.attribute[0].t.data_type
    fmt, size = data_type_tab[data_type]
    data = struct.unpack('<' + fmt * (len(raw_data) // size), raw_data)
    if len(data) == 1:
        data = data[0]

    def constant():
        return torch.tensor(data)
    return constant, [], node.output

def rebuild_sum(node, weights):
    def sum(*inputs):
        ret = inputs[0]
        for i in inputs[1:]:
            ret += i
        return ret
    return sum, node.input, node.output

def rebuild_shape(node, weights):
    def shape(x):
        return torch.tensor(list(x.shape))
    return shape, node.input, node.output

def rebuild_gather(node, weights):
    axis = node.attribute[0].i

    def gather(x, idx):
        return torch.gather(x, axis, idx)
    return gather, node.input, node.output

def _nd_unsqueeze(x, dims):
    dims = sorted(dims)
    for d in dims:
        x = torch.unsqueeze(x, dim=d)
    return x

def rebuild_unsqueeze(node, weights):
    axes = node.attribute[0].ints

    def unsqueeze(x):
        return _nd_unsqueeze(x, axes)
    return unsqueeze, node.input, node.output

def rebuild_mul(node, weights):
    def mul(a, b):
        return a * b
    return mul, node.input, node.output

def rebuild_softmax(node, weights):
    def f_softmax(x):
        return x.softmax(dim=1, dtype=torch.double).float()
    return f_softmax, node.input, node.output

def rebuild_reshape(node, weights):
    def reshape(x, s):
        data_shape = x.shape
        onnx_shape = s.tolist()
        pt_shape = []
        for idx, d in enumerate(onnx_shape):
            if d == 0:
                # In ONNX, 0 in the target shape means "keep this dimension unchanged".
                pt_shape.append(data_shape[idx])
            else:
                pt_shape.append(d)
        return torch.reshape(x, pt_shape)
    return reshape, node.input, node.output

def rebuild_averagepool(node, weights):
    rebuild_averagepool.avg_attr_map = {
        "pads": "padding",
        "strides": "stride",
        "kernel_shape": "kernel_size",
    }
    kwargs = {}
    for att in node.attribute:
        kwargs[rebuild_averagepool.avg_attr_map[att.name]] = list(att.ints)
    if 'padding' in kwargs:
        kwargs["padding"] = _slim422(kwargs["padding"])
    ap = nn.AvgPool2d(**kwargs)
    return ap, node.input, node.output

def rebuild_op(node, weights):
    op_type = node.op_type
    return globals()['rebuild_' + op_type.lower()](node, weights)

def construct_pytorch_nodes(graph, weights):
    ret = []
    for single_node in graph.node:
        ret.append(rebuild_op(single_node, weights))
    return ret

# Recursively evaluate the op that produces `name`, caching every result in inter_tensors.
def resolve_deps(name, deps, inter_tensors):
    if name in inter_tensors:
        return
    else:
        op, deps_names = deps[name]
        args = []
        for deps_name in deps_names:
            resolve_deps(deps_name, deps, inter_tensors)
            args.append(inter_tensors[deps_name])
        result = op(*args)
        inter_tensors[name] = result

# Wraps the rebuilt graph: maps every output tensor name to (op, input names)
# and resolves the requested output lazily in forward().
class DependencyModule(nn.Module):
    def __init__(self, onnx_model, input_name=None):
        super(DependencyModule, self).__init__()
        self.deps = {}
        self.inter_tensors = dict()
        self.weights = unpack_weights(onnx_model.graph.initializer)
        nodes = construct_pytorch_nodes(onnx_model.graph, self.weights)
        for idx, (node, inputs, outputs) in enumerate(nodes):
            if isinstance(node, nn.Module):
                self.add_module(str(idx), node)
            for output_name in outputs:
                self.deps[output_name] = (node, inputs)
        self.input_name = onnx_model.graph.input[0].name    # TODO only a single input is supported
        self.output_name = onnx_model.graph.output[0].name  # TODO only a single output is supported
        if input_name is not None:
            self.input_name = input_name

    def forward(self, input):
        self.inter_tensors = self.weights.copy()
        self.inter_tensors[self.input_name] = input
        resolve_deps(self.output_name, self.deps, self.inter_tensors)
        return self.inter_tensors[self.output_name]

def test_net(original_model, onnx_file):
    import time
    original_model.eval()
    onnx_model = onnx.load(onnx_file)
    reconstruct_model = DependencyModule(onnx_model)
    reconstruct_model.eval()
    input = torch.randn(3, 3, 224, 224)
    s = time.time()
    r1 = original_model(input)
    print("Original:", time.time() - s)
    s = time.time()
    r = reconstruct_model(input)
    print("DependencyModule:", time.time() - s)
    print("Max error for", onnx_file, ":", (r - r1).abs().max().item())

def main():
    test_net(tv.models.resnet18(True), "res18.onnx")
    test_net(tv.models.resnet50(True), "res50.onnx")
    test_net(tv.models.densenet121(True), "dense121.onnx")

if __name__ == '__main__':
    main()

# onnx2pytorch_validate.py: compare the rebuilt PyTorch models against MXNet's ONNX importer.
import mxnet.contrib.onnx as onnx_mxnet
import mxnet as mx
import numpy as np
import torch
import onnx
import onnx2pytorch as oi
from collections import namedtuple

def construct_mxnet_model(onnx_file, test_input):
    sym, arg, aux = onnx_mxnet.import_model(onnx_file)
    data_names = [graph_input for graph_input in sym.list_inputs()
                  if graph_input not in arg and graph_input not in aux]
    print("Input Blob Names:", data_names)
    mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
    print(sym)
    # exit(0)
    mod.bind(for_training=False, data_shapes=[(data_names[0], test_input.shape)], label_shapes=None)
    mod.set_params(arg_params=arg, aux_params=aux, allow_missing=True, allow_extra=True)
    Batch = namedtuple('Batch', ['data'])
    # forward on the provided data batch
    mod.forward(Batch([mx.nd.array(test_input)]))
    output = mod.get_outputs()[0]
    mo = output.asnumpy()
    return mo

def construct_pytorch_model(onnx_file, test_input):
    onnx_model = onnx.load(onnx_file)
    if onnx_file == "densenet121.onnx":
        reconstruct_model = oi.DependencyModule(onnx_model, input_name="data_0")
    else:
        reconstruct_model = oi.DependencyModule(onnx_model)
    reconstruct_model.eval()
    i = torch.from_numpy(test_input).float()
    o = reconstruct_model(i).detach().numpy()
    return o

def test_onnx_model(onnx_file):
    print("=" * 80)
    print(onnx_file, ":")
    test_input = np.random.randn(1, 3, 224, 224) / 10
    o = construct_pytorch_model(onnx_file, test_input)
    mo = construct_mxnet_model(onnx_file, test_input)
    abs_error = np.absolute(mo - o)
    print(abs_error.max(), abs_error.mean(), abs_error.min())
    print(mo[0][:5])
    print(o[0][:5])

def main():
    ok_onnx_model_files = [
        "googlenet.onnx",        # OK: special padding setting not supported by PyTorch MaxPool; with Softmax()
        "resnet18v2.onnx",       # OK
        "resnet34v2.onnx",       # OK
        "squeezenet1.1.onnx",    # OK
        "mobilenetv2-1.0.onnx",  # OK
        "alex_net.onnx",         # OK, but max error is not small enough; with Softmax()
        "densenet121.onnx",      # OK, but input_name is 'data_0', not '0' as in onnx.graph.input
        "vgg16.onnx",            # OK
        # "inception_v2.onnx",   # TODO wrong output, with Softmax()
        # "inception_v1.onnx",   # TODO Gemm weight shape at runtime
        # "shuffle_net.onnx",    # TODO wrong output, maybe caused by Transpose or Softmax()
    ]
    for model_file in ok_onnx_model_files:
        test_onnx_model(model_file)

if __name__ == '__main__':
    main()
Hi, thank you for publishing this excellent source code. What is the license of this code? Can I change it and use it for commercial purposes?
Hi, @akawashiro .
Sorry for the late reply.
I always prefer the MIT license, and it's compatible with commercial usage.
Thank you.
Hi, I'm sorry, I'm new to machine learning but I have to build a converter from ONNX to PyTorch, so your code is a big help. If it's not too much to ask, could you elaborate with an example on how your code works?
@MonTer998
The code starting here shows a simple usage of this script:
https://gist.github.com/qinjian623/6aa777037534c1c1dccbb66f832e93b8#file-onnx2pytorch_validate-py-L29
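In short: load the .onnx file with onnx.load(), wrap it in DependencyModule, then call the resulting module like any other nn.Module. Below is a minimal, untested sketch; it assumes the gist is saved locally as onnx2pytorch.py and that an exported model file such as resnet18.onnx exists (both names are placeholders for your own files).

import onnx
import torch
import onnx2pytorch as oi

onnx_model = onnx.load("resnet18.onnx")   # parse the ONNX protobuf
model = oi.DependencyModule(onnx_model)   # rebuild it as a PyTorch nn.Module
model.eval()

x = torch.randn(1, 3, 224, 224)           # dummy NCHW input
with torch.no_grad():
    y = model(x)                          # runs the reconstructed graph
print(y.shape)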
@qinjian623 thanks a million