import tensorflow as tf
from tensorflow.core.framework import types_pb2, graph_pb2, attr_value_pb2
from tensorflow.tools.graph_transforms import TransformGraph
from google.protobuf import text_format
import numpy as np

# Object Detection API input and output nodes
# (TransformGraph expects lists of node names, not bare strings)
input_name = ["image_tensor"]
output_names = ["detection_boxes", "detection_classes", "detection_scores", "num_detections"]

# These consts must stay float32 during NMS in the Object Detection API
# (see https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/non-max-suppression-v4.html)
keep_fp32_node_name = [
    "Postprocessor/BatchMultiClassNonMaxSuppression/MultiClassNonMaxSuppression/non_max_suppression/iou_threshold",
    "Postprocessor/BatchMultiClassNonMaxSuppression/MultiClassNonMaxSuppression/non_max_suppression/score_threshold",
]


def load_graph(model_path):
    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        if model_path.endswith("pb"):
            # binary protobuf
            with open(model_path, "rb") as f:
                graph_def.ParseFromString(f.read())
        else:
            # text protobuf
            with open(model_path, "r") as pf:
                text_format.Parse(pf.read(), graph_def)
        tf.import_graph_def(graph_def, name="")
        sess = tf.Session(graph=graph)
        return sess


def rewrite_batch_norm_node_v2(node, graph_def, target_type='fp16'):
    """
    Rewrite FusedBatchNorm as FusedBatchNormV2, because reserve_space_1 and
    reserve_space_2 in FusedBatchNorm require float32 for gradient calculation
    (see https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/fused-batch-norm).
    """
    if target_type == 'fp16':
        dtype = types_pb2.DT_HALF
    elif target_type == 'fp64':
        dtype = types_pb2.DT_DOUBLE
    else:
        dtype = types_pb2.DT_FLOAT
    new_node = graph_def.node.add()
    new_node.op = "FusedBatchNormV2"
    new_node.name = node.name
    new_node.input.extend(node.input)
    # "U" is the dtype of the mean/variance tensors and must stay float32
    new_node.attr["U"].CopyFrom(attr_value_pb2.AttrValue(type=types_pb2.DT_FLOAT))
    for attr in list(node.attr.keys()):
        if attr == "T":
            node.attr[attr].type = dtype
        new_node.attr[attr].CopyFrom(node.attr[attr])
    print("rewrite fused_batch_norm done!")


def convert_graph_to_fp16(model_path, save_path, name, as_text=False, target_type='fp16',
                          input_name=None, output_names=None):
    if target_type == 'fp16':
        dtype = types_pb2.DT_HALF
    elif target_type == 'fp64':
        dtype = types_pb2.DT_DOUBLE
    else:
        dtype = types_pb2.DT_FLOAT
    source_sess = load_graph(model_path)
    source_graph_def = source_sess.graph.as_graph_def()
    target_graph_def = graph_pb2.GraphDef()
    target_graph_def.versions.CopyFrom(source_graph_def.versions)
    for node in source_graph_def.node:
        # fused batch norm node
        if node.op == "FusedBatchNorm":
            rewrite_batch_norm_node_v2(node, target_graph_def, target_type=target_type)
            continue
        # replicate node
        new_node = target_graph_def.node.add()
        new_node.op = node.op
        new_node.name = node.name
        new_node.input.extend(node.input)
        attrs = list(node.attr.keys())
        # keep batch norm params node
        if ("BatchNorm" in node.name) or ('batch_normalization' in node.name):
            for attr in attrs:
                new_node.attr[attr].CopyFrom(node.attr[attr])
            continue
        # replace dtype in node attrs with the target dtype
        for attr in attrs:
            # keep special nodes in fp32
            if node.name in keep_fp32_node_name:
                new_node.attr[attr].CopyFrom(node.attr[attr])
                continue
            if node.attr[attr].type == types_pb2.DT_FLOAT:
                # dtype attrs such as "T", "SrcT", "DstT", "Tparams": switch to the
                # target dtype, and continue so the CopyFrom below doesn't undo it
                new_node.attr[attr].type = dtype
                continue
            if attr == "value":
                tensor = node.attr[attr].tensor
                if tensor.dtype == types_pb2.DT_FLOAT:
                    # if float_val exists
                    if tensor.float_val:
                        float_val = tf.make_ndarray(node.attr[attr].tensor)
                        new_node.attr[attr].tensor.CopyFrom(tf.make_tensor_proto(float_val, dtype=dtype))
                        continue
                    # if tensor content exists
                    if tensor.tensor_content:
                        tensor_shape = [x.size for x in tensor.tensor_shape.dim]
                        tensor_weights = tf.make_ndarray(tensor)
                        # reshape tensor
                        tensor_weights = np.reshape(tensor_weights, tensor_shape)
                        tensor_proto = tf.make_tensor_proto(tensor_weights, dtype=dtype)
                        new_node.attr[attr].tensor.CopyFrom(tensor_proto)
                        continue
            new_node.attr[attr].CopyFrom(node.attr[attr])
    # transform graph
    if output_names:
        if not input_name:
            input_name = []
        transforms = ["strip_unused_nodes"]
        target_graph_def = TransformGraph(target_graph_def, input_name, output_names, transforms)
    # write graph_def to model
    tf.io.write_graph(target_graph_def, logdir=save_path, name=name, as_text=as_text)
    print("Converting done ...")


model_path = "model.pb"  # placeholder: path to your frozen fp32 graph
save_path = "test"
name = "test.pb"
as_text = False
target_type = 'fp16'
convert_graph_to_fp16(model_path, save_path, name, as_text=as_text, target_type=target_type,
                      input_name=input_name, output_names=output_names)
# test loading
# ISSUE: loading the detection model is extremely slow, while loading the classification model is normal
sess = load_graph(save_path + "/" + name)
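
To sanity-check the converted graph, a minimal smoke test along these lines should work (a sketch: the 300x300 dummy input shape is an assumption for an SSD-style detector; the tensor names come from the input/output lists defined at the top of the script):

# Smoke test (sketch): run the converted detection graph on a dummy image.
# image_tensor in the Object Detection API takes a uint8 batch; 300x300 is an assumed size.
dummy_image = np.zeros((1, 300, 300, 3), dtype=np.uint8)
fetches = [n + ":0" for n in output_names]
boxes, classes, scores, num = sess.run(fetches, feed_dict={input_name[0] + ":0": dummy_image})
print("num_detections:", num)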
@glennford49 I'm not sure about this error. Could you provide your model?
I just downloaded the model 20180402-114759.pb, trained on VGGFace2, from David Sandberg's GitHub. I have used this model without converting it to fp16 and it works fine, but when I try to convert it with your code, replacing the input name with input and the output name with embeddings, I get errors:
ValueError: Input 0 of node InceptionResnetV1/Conv2d_1a_3x3/BatchNorm/cond/FusedBatchNorm was passed float from InceptionResnetV1/Conv2d_1a_3x3/BatchNorm/cond/FusedBatchNorm/Switch:1 incompatible with expected half.
I'm trying to convert a pretrained FaceNet model with input_names='input' and output_names='embeddings', using TensorFlow 1.14.0.
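Concretely, the call looks like this (a sketch; the model path is a placeholder for wherever the downloaded .pb lives):

# Sketch: converting the frozen FaceNet graph (placeholder paths)
convert_graph_to_fp16("20180402-114759.pb", "fp16_model", "facenet_fp16.pb",
                      as_text=False, target_type='fp16',
                      input_name=["input"], output_names=["embeddings"])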
Same issue here. Any idea how to solve it?
Some bugs:
- fp16 cannot represent values larger than tf.float16.max, so some values overflow to inf after the conversion.
- for the attrs "SrcT", "T", "Tparams", and "DstT", this code doesn't make any change.
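
One possible fix for the overflow is to saturate weights to the fp16 range before building the tensor proto (a sketch against the tensor_content branch above; np.finfo(np.float16) gives the representable range):

# Sketch: clamp values into the fp16 range so out-of-range weights saturate
# at +/-65504 instead of overflowing to inf after the cast.
def to_fp16_saturating(weights):
    fp16 = np.finfo(np.float16)
    return np.clip(weights, fp16.min, fp16.max).astype(np.float16)

# In convert_graph_to_fp16, for target_type == 'fp16' the tensor_content branch becomes:
#     tensor_proto = tf.make_tensor_proto(to_fp16_saturating(tensor_weights))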
I have copied your code entirely but it still gives me the same error. I will try downgrading TF to 1.13.1.