Created
August 17, 2021 06:48
-
-
Save nprithviraj24/a0f7930863859bdce327aa6707ed9abb to your computer and use it in GitHub Desktop.
This script takes the path to a labelme JSON annotation file and creates a binary segmentation mask from its shapes.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import math | |
| import sys | |
| import json | |
| import numpy as np | |
| import PIL.Image | |
| import PIL.ImageDraw | |
| import cv2 | |
def shape_to_mask(
    img_shape, points, shape_type=None, line_width=10, point_size=5, mask=None
):
    """Rasterize one labelme-style shape into a binary uint8 mask.

    Args:
        img_shape: Shape of the target image. Only the first two entries
            (height, width) are used, and only when ``mask`` is None.
        points: Iterable of (x, y) coordinates describing the shape.
        shape_type: One of "circle", "rectangle", "line", "linestrip" or
            "point". Any other value, including None, is drawn as a filled
            polygon.
        line_width: Stroke width in pixels for "line"/"linestrip". It is
            rounded to an int because Pillow's line drawing only accepts
            integer widths.
        point_size: Radius in pixels of the disc drawn for a "point".
        mask: Optional existing 2-D array to draw onto. Every nonzero pixel in
            it remains foreground in the result, so masks returned by earlier
            calls can be accumulated by passing them back in.

    Returns:
        A new uint8 array that is 255 where the shape (or the incoming
        ``mask``) is set and 0 elsewhere.

    Raises:
        AssertionError: If ``points`` has the wrong number of points for
            ``shape_type``. The error is raised explicitly rather than with
            ``assert``, so the check also runs under ``python -O``.
    """
    xy = [tuple(point) for point in points]

    # Validate before allocating anything: fixed-arity shapes need an exact
    # point count; polygons need at least three points. "linestrip" is
    # unchecked, matching the original behaviour.
    required_points = {"circle": 2, "rectangle": 2, "line": 2, "point": 1}
    if shape_type in required_points:
        expected = required_points[shape_type]
        if len(xy) != expected:
            raise AssertionError(
                f"Shape of shape_type={shape_type} must have {expected} points"
            )
    elif shape_type != "linestrip" and len(xy) <= 2:
        raise AssertionError("Polygon must have points more than 2")

    if mask is None:
        mask = np.zeros(img_shape[:2], dtype=np.uint8)
    canvas = PIL.Image.fromarray(mask)
    draw = PIL.ImageDraw.Draw(canvas)

    if shape_type == "circle":
        # First point is the centre, second lies on the circumference.
        (cx, cy), (px, py) = xy
        radius = math.hypot(cx - px, cy - py)
        draw.ellipse(
            [cx - radius, cy - radius, cx + radius, cy + radius], outline=1, fill=1
        )
    elif shape_type == "rectangle":
        # The two corners may be given in any order; newer Pillow versions
        # reject boxes whose second corner is above/left of the first.
        (x1, y1), (x2, y2) = xy
        draw.rectangle(
            [min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)], outline=1, fill=1
        )
    elif shape_type in ("line", "linestrip"):
        draw.line(xy=xy, fill=1, width=int(round(line_width)))
    elif shape_type == "point":
        cx, cy = xy[0]
        r = point_size
        draw.ellipse([cx - r, cy - r, cx + r, cy + r], outline=1, fill=1)
    else:
        draw.polygon(xy=xy, outline=1, fill=1)

    # Collapse whatever pixel values are present (1 from drawing, 255 from an
    # accumulated input mask) into a clean 0/255 uint8 mask.
    filled = np.asarray(canvas) != 0
    result = np.zeros(filled.shape, dtype=np.uint8)
    result[filled] = 255
    return result
import os

# Stroke sizes for "line"/"linestrip" and "point" shapes. Pillow only accepts
# integer line widths, so 1 px is the thinnest usable stroke.
LINE_WIDTH = 1
POINT_SIZE = 0.5

if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} LABELME_JSON")

json_path = sys.argv[1]
with open(json_path, "r", encoding="utf-8") as f:
    dj = json.load(f)

img_shape = (dj["imageHeight"], dj["imageWidth"])
mask = np.zeros(img_shape, dtype=np.uint8)

# Merge every shape (regardless of label) into one binary foreground mask.
for shape in dj["shapes"]:
    mask = shape_to_mask(
        img_shape,
        shape["points"],
        # Honor each shape's own type instead of forcing polygon; a missing
        # "shape_type" key (presumably older annotations) falls back to polygon.
        shape_type=shape.get("shape_type"),
        line_width=LINE_WIDTH,
        point_size=POINT_SIZE,
        mask=mask,
    )

# splitext only strips the final extension, so "./data/img.json" becomes
# "./data/img_mask.png" rather than splitting at the first dot in the path.
out_path = f"{os.path.splitext(json_path)[0]}_mask.png"
if not cv2.imwrite(out_path, mask):
    sys.exit(f"failed to write mask to {out_path}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment