-
-
Save nuzrub/ee3dc19242915278e95cb75014e29083 to your computer and use it in GitHub Desktop.
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import tensorflow as tf | |
| import numpy as np | |
| import os | |
| from object_detection.utils import label_map_util | |
| from object_detection.utils import config_util | |
| from object_detection.utils import visualization_utils as viz_utils | |
| from object_detection.builders import model_builder | |
# Input locations: the CenterNet model directory (holding the pipeline config
# and the checkpoint folder), the COCO label map, and the image to run on.
center_net_path = './centernet_resnet50_v1_fpn_512x512_coco17_tpu-8/'
pipeline_config = os.path.join(center_net_path, 'pipeline.config')
model_path = os.path.join(center_net_path, 'checkpoint/')
label_map_path = './mscoco_label_map.pbtxt'
image_path = './test.jpg'
# Parse the pipeline file and build the detection model it describes,
# in inference mode (is_training=False).
configs = config_util.get_configs_from_pipeline_file(pipeline_config)
model_config = configs['model']
detection_model = model_builder.build(
    model_config=model_config,
    is_training=False,
)

# Load the trained weights from the 'ckpt-0' checkpoint in model_path.
# expect_partial() suppresses warnings about checkpoint values that this
# model object does not consume.
ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
ckpt.restore(os.path.join(model_path, 'ckpt-0')).expect_partial()
def get_model_detection_function(model):
    """Wrap *model* in a ``tf.function`` that runs the full inference pipeline.

    Args:
        model: A detection model (here the one built by ``model_builder.build``)
            exposing ``preprocess``, ``predict`` and ``postprocess``.

    Returns:
        ``detect_fn(image)``, which returns the tuple
        ``(detections, prediction_dict, shapes)``:
        ``detections`` is the output of ``model.postprocess``,
        ``prediction_dict`` holds the raw output of ``model.predict`` (before
        post-processing), and ``shapes`` is the shapes tensor returned by
        ``model.preprocess``, flattened to 1-D.
    """

    @tf.function
    def detect_fn(image):
        preprocessed, true_shapes = model.preprocess(image)
        raw_predictions = model.predict(preprocessed, true_shapes)
        postprocessed = model.postprocess(raw_predictions, true_shapes)
        flat_shapes = tf.reshape(true_shapes, [-1])
        return postprocessed, raw_predictions, flat_shapes

    return detect_fn
# Wrap the restored model in a tf.function used for inference below.
detect_fn = get_model_detection_function(detection_model)

# Load the COCO label map and derive the lookup tables used when drawing labels.
# label_map_path is defined with the other paths at the top of the script, so
# it is used directly (the previous `label_map_path = label_map_path` was a
# no-op self-assignment).
label_map = label_map_util.load_labelmap(label_map_path)
categories = label_map_util.convert_label_map_to_categories(
    label_map,
    max_num_classes=label_map_util.get_max_label_map_index(label_map),
    use_display_name=True)
# Categories keyed by category id; passed to the visualizer.
category_index = label_map_util.create_category_index(categories)
# Display-name -> id mapping; not used further in this script.
label_map_dict = label_map_util.get_label_map_dict(label_map, use_display_name=True)
# Load the input image as an RGB array. convert('RGB') normalises grayscale,
# palette, RGBA and CMYK files, which would otherwise yield a 2-D or 4-channel
# array instead of the [H, W, 3] layout this RGB detector is fed. The `with`
# block closes the underlying file handle once the pixels are copied out.
with Image.open(image_path) as pil_image:
    image = np.array(pil_image.convert('RGB'))

# Add a batch dimension ([H, W, 3] -> [1, H, W, 3]) and run inference.
input_tensor = tf.convert_to_tensor(np.expand_dims(image, 0), dtype=tf.float32)
detections, predictions_dict, shapes = detect_fn(input_tensor)

# Offset added to the model's class ids before looking them up in
# category_index (presumably 0-based model ids vs. 1-based label-map ids).
label_id_offset = 1
# Draw on a copy so `image` itself stays unannotated.
image_np_with_detections = image.copy()

# Use keypoints if available in detections (first and only batch element).
keypoints, keypoint_scores = None, None
if 'detection_keypoints' in detections:
    keypoints = detections['detection_keypoints'][0].numpy()
    keypoint_scores = detections['detection_keypoint_scores'][0].numpy()
def get_keypoint_tuples(eval_config):
    """Return the keypoint edges declared in an eval config as tuples.

    Args:
        eval_config: Object with an iterable ``keypoint_edge`` attribute whose
            items expose ``start`` and ``end`` (in this script,
            ``configs['eval_config']`` from the pipeline file).

    Returns:
        A list of ``(start, end)`` tuples in declaration order; an empty list
        when no edges are declared.
    """
    return [(edge.start, edge.end) for edge in eval_config.keypoint_edge]
# Take the first (only) image in the batch and convert each output to NumPy.
boxes = detections['detection_boxes'][0].numpy()
class_ids = detections['detection_classes'][0].numpy()
scores = detections['detection_scores'][0].numpy()

# Draw up to 200 boxes scoring above 0.30 (plus keypoints, when the model
# produced them) onto the copy of the input image, in place.
viz_utils.visualize_boxes_and_labels_on_image_array(
    image_np_with_detections,
    boxes,
    (class_ids + label_id_offset).astype(int),
    scores,
    category_index,
    use_normalized_coordinates=True,
    max_boxes_to_draw=200,
    min_score_thresh=0.30,
    agnostic_mode=False,
    keypoints=keypoints,
    keypoint_scores=keypoint_scores,
    keypoint_edges=get_keypoint_tuples(configs['eval_config']),
)

# Render the annotated image on a 12x16-inch figure, save it, then display it.
plt.figure(figsize=(12, 16))
plt.imshow(image_np_with_detections)
plt.savefig('./output.png')
plt.show()
nuzrub
commented
Dec 1, 2020
via email
I'm trying to get the object predictions before the class of every region is inferred. I think your code can help me reach my goal, but I'm not sure.
I need to get the elements or regions before class inference, i.e. all the regions proposed by the model before the class of each one is detected. I need access to all the proposed regions so I can modify them or insert new regions, and then get the class of each one.
I will check it. I think that could work.
Thanks for helping me, nuzrub ;)
In this case I am working with CenterNet, but I will keep it in mind in case I need alternative models.
First, I would like to say thank you for the easiest-to-follow tutorial on Medium/TowardsDataScience.
On line 18, `label_map_path = './coco_labelmap.pbtxt'` should be `'./mscoco_label_map.pbtxt'`. The label map file must have been renamed.
First, I would like to say thank you for the easiest-to-follow tutorial on Medium/TowardsDataScience.
On line 18, `label_map_path = './coco_labelmap.pbtxt'` should be `'./mscoco_label_map.pbtxt'`. The label map file must have been renamed.
Glad you enjoyed it. I fixed the label map path, thanks :)