Object detection - WIDERFace
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import models, datasets, ops, tv_tensors
from torchvision.transforms import v2 as transforms
import pytorch_lightning as pl

# Step 1a: Define the transform (to tensor, normalization, resizing)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.Resize(size=(800,), max_size=1333),
])
# Step 1b: Convert a dataset item to the target structure accepted by torchvision detection models
def convert_inputs(imgs, annots, device, small_thr=0.001):
    images, targets = [], []
    for img, annot in zip(imgs, annots):
        bbox = annot['bbox']
        # drop boxes whose area is below `small_thr` of the image area
        small = (bbox[:, 2] * bbox[:, 3]) <= (img.size[1] * img.size[0] * small_thr)
        boxes = ops.box_convert(bbox[~small], in_fmt='xywh', out_fmt='xyxy')
        # wrap as a float tv_tensor so the v2 transforms (e.g. Resize) scale the boxes together with the image
        boxes = tv_tensors.BoundingBoxes(boxes.float(), format="XYXY", canvas_size=(img.height, img.width))
        output_dict = transform({"image": img, "boxes": boxes})
        images.append(output_dict['image'].to(device))
        targets.append({
            'boxes': output_dict['boxes'].to(device),
            'labels': torch.ones(len(boxes), dtype=torch.int64, device=device),
        })
    return images, targets
# Step 2: Wrap a pretrained Faster R-CNN model from torchvision in a LightningModule
class FaceDetectionModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = models.detection.fasterrcnn_mobilenet_v3_large_fpn(weights="DEFAULT")

    def forward(self, images, targets=None):
        if targets is None:
            return self.model(images)
        return self.model(images, targets)

    def training_step(self, batch, batch_idx):
        imgs, annot = batch
        images, targets = convert_inputs(imgs, annot, device=self.device)
        loss_dict = self.model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        self.log('train_loss', losses)
        return losses

    def configure_optimizers(self):
        return optim.SGD(self.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005)
# Step 3: Define a collate function that keeps variable-sized images and annotations as tuples
def collate_fn(batch):
    return tuple(zip(*batch))

if __name__ == "__main__":
    # Step 4: Load the WIDERFace dataset using torchvision.datasets
    train_dataset = datasets.WIDERFace(root='./data', split='train', download=True)

    # Step 5: Set up the DataLoader
    train_loader = DataLoader(
        train_dataset, batch_size=8, shuffle=True, num_workers=4, collate_fn=collate_fn
    )

    # Step 6: Create and train the model
    model = FaceDetectionModel()
    trainer = pl.Trainer(max_epochs=5, precision='16-mixed', log_every_n_steps=10)
    trainer.fit(model, train_dataloaders=train_loader)
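For reference, torchvision's detection models expect targets as a list of dicts with absolute xyxy boxes (FloatTensor of shape [N, 4]) and integer class labels (Int64Tensor of shape [N], where 0 is reserved for background), which is what convert_inputs builds above. A minimal standalone sketch of the xywh to xyxy conversion, using a made-up box:

import torch
from torchvision import ops

# one hypothetical WIDERFace-style annotation: x, y, width, height
bbox_xywh = torch.tensor([[10.0, 20.0, 30.0, 40.0]])
bbox_xyxy = ops.box_convert(bbox_xywh, in_fmt='xywh', out_fmt='xyxy')
print(bbox_xyxy)  # tensor([[10., 20., 40., 60.]])

# target dict in the shape Faster R-CNN consumes; label 1 means "face"
target = {"boxes": bbox_xyxy, "labels": torch.ones(len(bbox_xyxy), dtype=torch.int64)}

Lightning writes checkpoints under lightning_logs/version_N/checkpoints/ by default, which is the location the prediction script below globs for.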
Prediction and visualization script (expects the training script above to be saved as face_detection.py):
import sys
import glob
import torch
import numpy as np
import torchvision.transforms.functional as F
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from torchvision import datasets, ops, utils
from torchvision.transforms import v2 as transforms
from face_detection import FaceDetectionModel

plt.rcParams["savefig.bbox"] = "tight"

# index of the training sample to visualize (first CLI argument, defaults to 0)
sample_idx = int(sys.argv[1]) if len(sys.argv) >= 2 else 0
print(f"selected image sample: {sample_idx}")
def show(imgs):
    if not isinstance(imgs, list):
        imgs = [imgs]
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(7 * len(imgs), 8))
    for i, img in enumerate(imgs):
        img = img.detach()
        img = F.to_pil_image(img)
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    return fig
# Step 1: Define the dataset transform (just ToTensor, so the raw image can be visualized)
transform = transforms.Compose([transforms.ToTensor()])

# normalization + resizing applied right before feeding the model (matches training)
normalize = transforms.Compose([
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.Resize(size=(800,), max_size=1333),
])

# Load the WIDERFace dataset using torchvision.datasets
train_dataset = datasets.WIDERFace(root='./data', split='train', download=True, transform=transform)
img, target = train_dataset[sample_idx]
img = F.convert_image_dtype(img, dtype=torch.uint8)
boxes = ops.box_convert(target['bbox'], in_fmt='xywh', out_fmt='xyxy')

# visualize the ground-truth annotations
annot = utils.draw_bounding_boxes(img, boxes, colors="red", width=5)
# Replace with the path to your trained checkpoint 'lightning_logs/version_x/checkpoints/...'
checkpoint_path = glob.glob("lightning_logs/version_6/checkpoints/*.ckpt")[0]
print(f"loading model from checkpoint '{checkpoint_path}'")

# Load the trained model
model = FaceDetectionModel.load_from_checkpoint(checkpoint_path).cpu()
model.eval()

# Get the model prediction for the same sample
img2, _ = train_dataset[sample_idx]
with torch.no_grad():
    output = model.model([normalize(img2)])
print(f"predictions: {output}")
boxes = output[0]['boxes'][output[0]['scores'] >= 0.5]

# visualize the predictions with confidence of at least 0.5
preds = utils.draw_bounding_boxes(img, boxes, colors="green", width=5)

# export the side-by-side figure
fig = show([annot, preds])
fig.savefig('figure.png')
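The script above only visualizes samples from the training set. As a rough sketch of running the trained detector on an arbitrary image file instead (the image path, checkpoint filename, and the 0.5 score threshold below are placeholders; the same external normalization as in the scripts above is assumed):

import torch
import torchvision.transforms.functional as F
from torchvision import io
from torchvision.transforms import v2 as transforms
from face_detection import FaceDetectionModel

normalize = transforms.Compose([
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.Resize(size=(800,), max_size=1333),
])

# load the trained checkpoint (path is a placeholder)
model = FaceDetectionModel.load_from_checkpoint("lightning_logs/version_6/checkpoints/last.ckpt").cpu()
model.eval()

# read an image as a uint8 CHW tensor and rescale it to float in [0, 1]
img = io.read_image("my_photo.jpg")
img = F.convert_image_dtype(img, dtype=torch.float32)

with torch.no_grad():
    output = model.model([normalize(img)])[0]
keep = output['scores'] >= 0.5
print(output['boxes'][keep], output['scores'][keep])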