# Snippet originally shared as a GitHub gist by @ashnair1
# (gist ashnair1/c792aff8cf74c98c60b92da8f5fe25ef, last active January 2, 2025).
import matplotlib.pyplot as plt
import torch
import torchvision.transforms.functional as F
import kornia.augmentation as K
import kornia.io as Kio
# Note: Regarding dict keys and/or data keys for boxes in AugmentationSequential.
# If using 'bbox', kornia assumes boxes are in 'vertices_plus' mode (i.e. (N,4,2) tensor).
# If using 'bbox_xyxy', kornia assumes boxes are in 'xyxy' mode (i.e. (N,4) tensor).
def run_augs(image, bbox, dict_mode=False):
    """Run a kornia augmentation pipeline over an image and its xyxy boxes.

    Args:
        image: Image tensor, passed to kornia under the ``"input"`` data key.
        bbox: Box tensor, passed under the ``"bbox_xyxy"`` data key.
        dict_mode: If True, call the pipeline with a dict whose keys name the
            data types (``data_keys=None``); otherwise call it positionally
            with an explicit ``data_keys`` list.

    Returns:
        Tuple ``(aug_image, aug_bbox)``.
    """
    augmentations = K.AugmentationSequential(
        # Pass the probability by keyword: the meaning of the first positional
        # argument has not been stable across kornia releases.
        K.RandomVerticalFlip(p=1.0),
        # K.RandomHorizontalFlip(p=1.0),
        # K.RandomRotation(45.0, p=1.0),
        data_keys=None if dict_mode else ["input", "bbox_xyxy"],
    )

    if dict_mode:
        # The dict keys themselves tell kornia which entry is the image and
        # which holds the boxes, hence data_keys=None above.
        aug_sample = augmentations({"input": image, "bbox_xyxy": bbox})
        return aug_sample["input"], aug_sample["bbox_xyxy"]

    aug_image, aug_bbox = augmentations(image, bbox)
    return aug_image, aug_bbox
_BOX_COLORS = ('red', 'blue', 'green', 'yellow', 'purple', 'orange')


def _draw_boxes(ax, boxes):
    """Draw xyxy boxes on ``ax`` with their corner points and coordinate labels.

    ``boxes`` is reshaped to ``(-1, 4)`` so an ``(N, 4)`` tensor and a batched
    ``(1, N, 4)`` tensor are both drawn one box per row. Rows are converted to
    Python floats so matplotlib and the f-string formatting never see tensors.
    """
    for b, (x_min, y_min, x_max, y_max) in enumerate(boxes.reshape(-1, 4).tolist()):
        color = _BOX_COLORS[b % len(_BOX_COLORS)]
        corners = [(x_min, y_min), (x_max, y_min), (x_min, y_max), (x_max, y_max)]
        ax.scatter([x for x, _ in corners], [y for _, y in corners], color=color)
        for x, y in corners:
            ax.text(x, y, f'({x:.0f}, {y:.0f})', color="black", fontsize=10, weight='bold')
        ax.add_patch(plt.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                                   fill=False, edgecolor=color, linewidth=2))


def main(image, bbox_tensor, dict_mode=False):
    """Augment an image and its boxes, then plot original and augmented side by side.

    Args:
        image: Image tensor. ``squeeze(0)`` drops a leading size-1 batch
            dimension before PIL conversion, so a ``(1, C, H, W)`` shape is
            presumably expected.
        bbox_tensor: Boxes in xyxy format, one ``(x_min, y_min, x_max, y_max)``
            row per box.
        dict_mode: Forwarded to ``run_augs`` to select dict vs. positional calling.
    """
    aug_image, aug_bbox = run_augs(image, bbox_tensor, dict_mode)

    _, axs = plt.subplots(1, 2, figsize=(16, 16))
    panels = zip(axs, (image, aug_image), (bbox_tensor, aug_bbox), ("Original", "Augmented"))
    for ax, img, boxes, title in panels:
        ax.axis("off")
        ax.imshow(F.to_pil_image(img.squeeze(0)))
        _draw_boxes(ax, boxes)
        ax.set_title(title)
    plt.show()
if __name__ == '__main__':
    # Tiny 3x3 example so every box coordinate is easy to verify by eye.
    image = torch.rand(1, 3, 3, 3)
    bbox = torch.tensor(
        [[0.0, 0.0, 2.0, 2.0], [0.0, 1.0, 1.0, 2.0], [1.0, 2.0, 3.0, 3.0]],
        dtype=torch.float,
    )

    # Alternative 1: random 224x224 image with two boxes.
    # Assign to `bbox` (not a new name) so the call below actually uses them.
    # image = torch.rand(1, 3, 224, 224)  # batch of 1, 3 color channels, 224x224
    # bbox = torch.tensor(
    #     [[50, 50, 100, 100],
    #      [100, 100, 150, 150]],
    #     dtype=torch.float32)

    # Alternative 2: a real image loaded from disk (requires panda.jpg).
    # image = Kio.load_image("panda.jpg", Kio.ImageLoadType.RGB32)[None, ...]
    # bbox = torch.tensor(
    #     [[100, 100, 400, 400],
    #      [350, 350, 450, 450],
    #      [700, 50, 900, 400],
    #      [550, 150, 700, 300]],
    #     dtype=torch.float32)

    main(image=image, bbox_tensor=bbox, dict_mode=False)
# To comment on this snippet, sign in on the original GitHub gist page.