Created
July 29, 2024 23:15
-
-
Save Borda/b419764e3ff28ab46d4e8a087e97b94e to your computer and use it in GitHub Desktop.
Image classification - StanfordCars
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import glob | |
| import torch | |
| import scipy | |
| import random | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
| from torchvision import transforms | |
| from train import LitClassification, ClassificationData | |
# Load the class names so a numerical label can be mapped to a readable name.
meta = scipy.io.loadmat('stanford_cars/devkit/cars_meta.mat')
labels = [it[0] for it in meta["class_names"][0]]

# Pick a random sample image from the test folder.
ls_images = glob.glob('stanford_cars/cars_test/*.jpg')
print(f"found {len(ls_images)} cars in test folder")
if not ls_images:
    # Fail with an actionable message instead of an IndexError on `ls_images[0]`.
    raise FileNotFoundError("no test images found in 'stanford_cars/cars_test/'")
random.shuffle(ls_images)
image = Image.open(ls_images[0]).convert('RGB')

# Preprocess the input image with the transform defined on the data module.
preprocess = ClassificationData().transform
tensor = preprocess(image).unsqueeze(0)  # Add batch dimension

# Replace with path to your trained checkpoint 'lightning_logs/version_x/checkpoints/...'
checkpoints = glob.glob("lightning_logs/version_*/checkpoints/*.ckpt")
if not checkpoints:
    # Same reasoning as above: an empty glob would otherwise surface as IndexError.
    raise FileNotFoundError(
        "no checkpoint found under 'lightning_logs/version_*/checkpoints/'; train the model first"
    )
checkpoint_path = checkpoints[0]
print(f"loading model from checkpoint '{checkpoint_path}'")

# Load the model on the CPU and switch it to inference mode.
model = LitClassification.load_from_checkpoint(checkpoint_path).cpu()
model.eval()

# Get the model prediction; gradients are not needed for inference.
with torch.no_grad():
    output = model.model(tensor)
# Reduce over the class dimension and convert to a plain Python int so it can
# be used directly as a list index.
pred = int(torch.argmax(output, dim=1).item())
print(f"prediction suggests {pred} which is '{labels[pred]}'")

# Show the input image with the predicted label as the title.
plt.figure(figsize=(12, 8))
plt.imshow(image)
plt.title(labels[pred])
plt.show()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from timm import create_model | |
| from torchvision import transforms, datasets | |
| import pytorch_lightning as pl | |
class LitClassification(pl.LightningModule):
    """Lightning module wrapping a timm classifier trained with cross-entropy.

    Args:
        model_name: Architecture name passed to ``create_model``.
        num_classes: Number of output classes (defaults to 196).
        lr: Learning rate for the AdamW optimizer.
    """

    def __init__(self, model_name: str = "resnet50", num_classes: int = 196, lr: float = 0.005):
        super().__init__()
        # Store the constructor arguments in the checkpoint so that
        # `load_from_checkpoint` rebuilds the same architecture.
        self.save_hyperparameters()
        self.lr = lr
        self.model = create_model(model_name, num_classes=num_classes)
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, images):
        """Return the raw class scores for a batch of images."""
        return self.model(images)

    def training_step(self, batch, batch_idx=None):
        """Compute and log the cross-entropy loss for one training batch.

        Args:
            batch: An ``(images, targets)`` pair.
            batch_idx: Index of the batch; accepted for the standard hook
                signature but not used.

        Returns:
            The loss tensor to optimize.
        """
        images, targets = batch
        outputs = self.model(images)
        loss = self.loss_fn(outputs, targets)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over the wrapped model's parameters."""
        return torch.optim.AdamW(self.model.parameters(), lr=self.lr)
class ClassificationData(pl.LightningDataModule):
    """Data module providing the StanfordCars training dataloader.

    Args:
        root: Root directory handed to ``datasets.StanfordCars``.
        batch_size: Number of samples per training batch.
        num_workers: Number of dataloader worker processes.
    """

    # Kept as a class attribute so other code can reuse the exact same
    # preprocessing via `ClassificationData().transform`.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    def __init__(self, root: str = ".", batch_size: int = 256, num_workers: int = 5):
        super().__init__()
        self.root = root
        self.batch_size = batch_size
        self.num_workers = num_workers

    def train_dataloader(self):
        """Return a shuffled dataloader over the StanfordCars dataset."""
        # The dataset is not downloaded here (download=False), so it has to be
        # present under `root` already; see
        # https://github.com/pytorch/vision/issues/7545#issuecomment-1631441616
        train_dataset = datasets.StanfordCars(root=self.root, download=False, transform=self.transform)
        return torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )
if __name__ == "__main__":
    # Script entry point: build the module and the data, then run the fit loop.
    classifier = LitClassification()
    datamodule = ClassificationData()
    trainer_options = dict(max_epochs=50, precision="16", log_every_n_steps=5)
    runner = pl.Trainer(**trainer_options)
    runner.fit(classifier, datamodule)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment