@Borda · Last active July 30, 2024
Object detection - WIDERFace
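Training script (the prediction script further down imports it as face_detection, so it is presumably saved as face_detection.py):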
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import models, datasets, ops, tv_tensors
from torchvision.transforms import v2 as transforms
import pytorch_lightning as pl
# Step 1a: Define the transform (ImageNet normalization, then resize so the
# shorter side is 800 px, capped at 1333 px on the longer side)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.Resize(size=(800,), max_size=1333),
])
# Step 1b: Convert dataset items to the target structure the model accepts
def convert_inputs(imgs, annots, device, small_thr=0.001):
    images, targets = [], []
    for img, annot in zip(imgs, annots):
        bbox = annot['bbox']
        # drop boxes whose area is below `small_thr` of the image area
        small = (bbox[:, 2] * bbox[:, 3]) <= (img.size[1] * img.size[0] * small_thr)
        boxes = ops.box_convert(bbox[~small], in_fmt='xywh', out_fmt='xyxy')
        # wrap the boxes as a tv_tensor so the v2 transforms resize them
        # together with the image; PIL's img.size is (width, height)
        boxes = tv_tensors.BoundingBoxes(
            boxes, format='XYXY', canvas_size=(img.size[1], img.size[0])
        )
        output_dict = transform({"image": img, "boxes": boxes})
        images.append(output_dict['image'].to(device))
        targets.append({
            'boxes': output_dict['boxes'].to(device),
            'labels': torch.ones(len(boxes), dtype=torch.int64, device=device),
        })
    return images, targets
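# For reference, torchvision detection models expect one target dict per image,
#   {"boxes": FloatTensor[N, 4] in xyxy format, "labels": Int64Tensor[N]},
# which is what convert_inputs builds above (label 1 = face; label 0 is
# reserved for the background class).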
# Step 2: Wrap a pretrained Faster R-CNN model from torchvision
class FaceDetectionModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = models.detection.fasterrcnn_mobilenet_v3_large_fpn(weights="DEFAULT")

    def forward(self, images, targets=None):
        # torchvision detection models return a dict of losses in training mode
        # and a list of detections (boxes, labels, scores) in eval mode
        if targets is None:
            return self.model(images)
        return self.model(images, targets)

    def training_step(self, batch, batch_idx):
        imgs, annots = batch
        images, targets = convert_inputs(imgs, annots, device=self.device)
        loss_dict = self.model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        self.log('train_loss', losses)
        return losses

    def configure_optimizers(self):
        return optim.SGD(self.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005)
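    # Note (an optional extension, not part of the original gist): Lightning also
    # accepts an optimizer/scheduler pair from configure_optimizers, e.g.
    #     optimizer = optim.SGD(self.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005)
    #     scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
    #     return {"optimizer": optimizer, "lr_scheduler": scheduler}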
# Step 3: Define a collate function to handle batches; detection images vary in
# size, so they cannot be stacked into a single tensor - the batch stays a
# tuple of images and a tuple of annotations
def collate_fn(batch):
    return tuple(zip(*batch))
if __name__ == "__main__":
    # Step 4: Load the WIDERFace dataset using torchvision.datasets
    train_dataset = datasets.WIDERFace(root='./data', split='train', download=True)
    # Step 5: Set up the DataLoader
    train_loader = DataLoader(
        train_dataset, batch_size=8, shuffle=True, num_workers=4, collate_fn=collate_fn
    )
    # Step 6: Create and train the model
    model = FaceDetectionModel()
    trainer = pl.Trainer(max_epochs=5, precision='16-mixed', log_every_n_steps=10)
    trainer.fit(model, train_dataloaders=train_loader)
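The second script of the gist loads a trained checkpoint and draws the model's predictions next to the ground-truth annotation for a chosen training sample. Running the training above produces the checkpoints it looks for under lightning_logs/version_*/checkpoints/.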
import sys
import glob
import torch
import numpy as np
import torchvision.transforms.functional as F
import matplotlib.pyplot as plt
from torchvision import datasets, ops, utils
from torchvision.transforms import v2 as transforms
from face_detection import FaceDetectionModel
plt.rcParams["savefig.bbox"] = "tight"
sample_idx = int(sys.argv[1]) if len(sys.argv) >= 2 else 0
print(f"selected image sample: {sample_idx}")
def show(imgs):
    if not isinstance(imgs, list):
        imgs = [imgs]
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(7 * len(imgs), 8))
    for i, img in enumerate(imgs):
        img = img.detach()
        img = F.to_pil_image(img)
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    return fig
# Step 1: Convert dataset images to tensors only, so they stay displayable
transform = transforms.Compose([transforms.ToTensor()])
# Normalization applied right before inference; resizing is left to the
# detection model itself, which maps the predicted boxes back to the
# coordinates of its input image, so they can be drawn on the original
normalize = transforms.Compose([
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Load the WIDERFace dataset using torchvision.datasets
train_dataset = datasets.WIDERFace(root='./data', split='train', download=True, transform=transform)
img, target = train_dataset[sample_idx]
img = F.convert_image_dtype(img, dtype=torch.uint8)
boxes = ops.box_convert(target['bbox'], in_fmt='xywh', out_fmt='xyxy')
# visualize the annotation
annot = utils.draw_bounding_boxes(img, boxes, colors="red", width=5)
# Replace with path to your trained checkpoint 'lightning_logs/version_x/checkpoints/...'
checkpoint_path = glob.glob("lightning_logs/version_6/checkpoints/*.ckpt")[0]
print(f"loading model from checkpoint '{checkpoint_path}'")
# Load the model
model = FaceDetectionModel.load_from_checkpoint(checkpoint_path).cpu()
model.eval()
# Get the model prediction
img2, _ = train_dataset[sample_idx]
with torch.no_grad():
    output = model.model([normalize(img2)])
print(f"predictions: {output}")
# keep only detections with a confidence score of at least 0.5
boxes = output[0]['boxes'][output[0]['scores'] >= 0.5]
# visualize the predictions
preds = utils.draw_bounding_boxes(img, boxes, colors="green", width=5)
# export figure
fig = show([annot, preds])
fig.savefig('figure.png')
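# Run as: python <this_script>.py [sample_idx], e.g. 3; the comparison figure
# is saved to figure.png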