Last active
January 2, 2025 17:56
-
-
Save ashnair1/c792aff8cf74c98c60b92da8f5fe25ef to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import matplotlib.pyplot as plt | |
| import torch | |
| import torchvision.transforms.functional as F | |
| import kornia.augmentation as K | |
| import kornia.io as Kio | |
# Note on dict keys / data keys for boxes in AugmentationSequential:
#   * 'bbox'      -> kornia assumes boxes are in 'vertices_plus' mode, i.e. an (N, 4, 2) tensor.
#   * 'bbox_xyxy' -> kornia assumes boxes are in 'xyxy' mode, i.e. an (N, 4) tensor.
def run_augs(image, bbox, dict_mode=False):
    """Run a fixed Kornia augmentation pipeline over an image and its boxes.

    Args:
        image: image tensor, fed to the pipeline under the "input" key.
        bbox: box tensor, fed to the pipeline under the "bbox_xyxy" key.
        dict_mode: if True, build the pipeline with ``data_keys=None`` and call
            it with a dict sample; otherwise build it with explicit
            ``data_keys`` and call it positionally.

    Returns:
        Tuple ``(aug_image, aug_bbox)`` taken from the pipeline's output.
    """
    # Positional calls need the key order spelled out; dict calls pass None
    # (presumably kornia then reads the keys from the dict -- confirm in docs).
    keys = None if dict_mode else ["input", "bbox_xyxy"]
    pipeline = K.AugmentationSequential(
        K.RandomVerticalFlip(1.0),
        # Other transforms that can be swapped in for experiments:
        # K.RandomHorizontalFlip(1.0),
        # K.RandomRotation(45.0, p=1.0),
        data_keys=keys,
    )

    if dict_mode:
        out = pipeline({"input": image, "bbox_xyxy": bbox})
        return out["input"], out["bbox_xyxy"]

    # Unpack so the result is always a 2-tuple, whatever sequence type
    # the pipeline hands back.
    aug_image, aug_bbox = pipeline(image, bbox)
    return aug_image, aug_bbox
# Edge colours for successive boxes; box b uses _BOX_COLORS[b % len(_BOX_COLORS)].
_BOX_COLORS = ('red', 'blue', 'green', 'yellow', 'purple', 'orange')


def _draw_boxes(ax, boxes):
    """Outline each box on ``ax`` and mark and label its four corners.

    Args:
        ax: matplotlib Axes to draw on.
        boxes: iterable of 4-element boxes, each unpacked as
            ``(x_min, y_min, x_max, y_max)``.
    """
    for b, (x_min, y_min, x_max, y_max) in enumerate(boxes):
        color = _BOX_COLORS[b % len(_BOX_COLORS)]
        corners = ((x_min, y_min), (x_max, y_min), (x_min, y_max), (x_max, y_max))
        ax.scatter([x for x, _ in corners], [y for _, y in corners], color=color)
        for x, y in corners:
            # Print each corner's coordinates so box positions before/after
            # augmentation can be read directly off the plot.
            ax.text(x,
                    y,
                    f'({x:.0f}, {y:.0f})',
                    color="black",
                    fontsize=10,
                    weight='bold')
        # Rectangle takes an anchor corner plus width/height.
        ax.add_patch(plt.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                                   fill=False, edgecolor=color, linewidth=2))


def main(image, bbox_tensor, dict_mode=False):
    """Augment an image and its boxes once, then plot original vs. augmented.

    Args:
        image: image tensor handed to ``run_augs`` and shown after
            ``squeeze(0)``; assumes a leading batch dim of size 1 (the
            ``__main__`` fixture uses shape (1, 3, H, W)).
        bbox_tensor: boxes in (x_min, y_min, x_max, y_max) order, one per row.
        dict_mode: forwarded to ``run_augs`` to pick dict vs. positional calling.
    """
    aug_image, aug_bbox = run_augs(image, bbox_tensor, dict_mode)

    _, axs = plt.subplots(1, 2, figsize=(16, 16))
    panels = zip(axs,
                 ("Original", "Augmented"),
                 (image, aug_image),
                 (bbox_tensor, aug_bbox))
    for ax, title, img, boxes in panels:
        ax.axis("off")
        # squeeze(0) drops the batch dim so to_pil_image gets a (C, H, W) tensor.
        ax.imshow(F.to_pil_image(img.squeeze(0)))
        _draw_boxes(ax, boxes)
        ax.set_title(title)
    plt.show()
if __name__ == '__main__':
    # Tiny 3x3 fixture -- presumably chosen so each box corner maps to a
    # single pixel and the effect of the flip can be checked by eye.
    image = torch.rand(1, 3, 3, 3)
    bbox = torch.tensor(
        [
            [0.0, 0.0, 2.0, 2.0],
            [0.0, 1.0, 1.0, 2.0],
            [1.0, 2.0, 3.0, 3.0],
        ],
        dtype=torch.float,
    )

    # Alternative fixture: random 224x224 image (batch of 1, 3 channels) with
    # two boxes. Assign to `bbox` so main() below actually uses these boxes.
    # image = torch.rand(1, 3, 224, 224)
    # bbox = torch.tensor(
    #     [
    #         [50, 50, 100, 100],
    #         [100, 100, 150, 150],
    #     ],
    #     dtype=torch.float32)

    # Alternative fixture: a real photo; expects panda.jpg in the current
    # working directory.
    # image = Kio.load_image("panda.jpg", Kio.ImageLoadType.RGB32)[None, ...]
    # bbox = torch.tensor(
    #     [
    #         [100, 100, 400, 400],
    #         [350, 350, 450, 450],
    #         [700, 50, 900, 400],
    #         [550, 150, 700, 300],
    #     ],
    #     dtype=torch.float32)

    main(image=image, bbox_tensor=bbox, dict_mode=False)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment