Skip to content

Instantly share code, notes, and snippets.

@Borda
Created July 29, 2024 23:15
Show Gist options
  • Select an option

  • Save Borda/b419764e3ff28ab46d4e8a087e97b94e to your computer and use it in GitHub Desktop.

Select an option

Save Borda/b419764e3ff28ab46d4e8a087e97b94e to your computer and use it in GitHub Desktop.
Image classification - StanfordCars
import glob
import os
import random

import matplotlib.pyplot as plt
import scipy
import scipy.io
import torch
from PIL import Image
from torchvision import transforms

from train import LitClassification, ClassificationData
# Load the class-name table; position i in `labels` is the name shown for predicted index i
meta = scipy.io.loadmat('stanford_cars/devkit/cars_meta.mat')
labels = [it[0] for it in meta["class_names"][0]]
# Collect the candidate test images
ls_images = glob.glob('stanford_cars/cars_test/*.jpg')
print(f"found {len(ls_images)} cars in test folder")
if not ls_images:
    # Fail with an actionable message instead of a bare IndexError on ls_images[0]
    raise FileNotFoundError(
        "no '*.jpg' images found in 'stanford_cars/cars_test/' - check the dataset location"
    )
# Shuffle so that a different sample is picked on every run
random.shuffle(ls_images)
# convert() returns an independent image, so the file handle can be released right away
with Image.open(ls_images[0]) as raw_image:
    image = raw_image.convert('RGB')
# Preprocess the input image with the transform pipeline exposed by the data module
preprocess = ClassificationData().transform
tensor = preprocess(image).unsqueeze(0)  # Add batch dimension
# Replace with path to your trained checkpoint 'lightning_logs/version_x/checkpoints/...'
checkpoint_candidates = glob.glob("lightning_logs/version_*/checkpoints/*.ckpt")
if not checkpoint_candidates:
    # Fail with an actionable message instead of a bare IndexError on an empty list
    raise FileNotFoundError(
        "no checkpoint found matching 'lightning_logs/version_*/checkpoints/*.ckpt' - train a model first"
    )
# glob returns matches in arbitrary order, so explicitly pick the most recently written checkpoint
checkpoint_path = max(checkpoint_candidates, key=os.path.getmtime)
print(f"loading model from checkpoint '{checkpoint_path}'")
# Load the model; map_location lets a GPU-trained checkpoint be restored on a CPU-only machine
model = LitClassification.load_from_checkpoint(checkpoint_path, map_location="cpu").cpu()
# Switch to inference mode (affects layers such as dropout / batch-norm)
model.eval()
# Run a single forward pass without tracking gradients and take the highest-scoring class
with torch.no_grad():
    logits = model.model(tensor)
    class_index = logits.squeeze().argmax(dim=0).cpu().numpy()
predicted_label = labels[class_index]
print(f"predistion suggest {class_index} whic is '{predicted_label}'")
# Display the input image, titled with the predicted class name
plt.figure(figsize=(12, 8))
plt.imshow(image)
plt.title(predicted_label)
plt.show()
import torch
from timm import create_model
from torchvision import transforms, datasets
import pytorch_lightning as pl
class LitClassification(pl.LightningModule):
    """Lightning module wrapping a timm classifier trained with cross-entropy loss.

    Args:
        model_name: Architecture name passed to ``timm.create_model``.
        num_classes: Size of the classification head.
        lr: Learning rate handed to the AdamW optimizer.

    The defaults reproduce the previously hard-coded setup
    (``resnet50``, 196 classes, learning rate 0.005).
    """

    def __init__(self, model_name: str = "resnet50", num_classes: int = 196, lr: float = 0.005):
        super().__init__()
        # Store the constructor arguments in checkpoints so that
        # `load_from_checkpoint` rebuilds the same architecture automatically.
        self.save_hyperparameters()
        self.lr = lr
        self.model = create_model(model_name, num_classes=num_classes)
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def training_step(self, batch):
        """Compute, log and return the training loss for one batch.

        Args:
            batch: An ``(images, targets)`` pair as yielded by the training dataloader.

        Returns:
            The cross-entropy loss between the model outputs and ``targets``.
        """
        images, targets = batch
        outputs = self.model(images)
        loss = self.loss_fn(outputs, targets)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over the wrapped model's parameters."""
        return torch.optim.AdamW(self.model.parameters(), lr=self.lr)
class ClassificationData(pl.LightningDataModule):
    """Data module exposing the training dataloader and the image preprocessing pipeline."""

    # Preprocessing applied to every image: convert to a tensor, resize to 224x224,
    # then normalize each channel with the fixed mean / std values below.
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize(size=(224, 224)),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    )

    def train_dataloader(self):
        """Build a shuffled DataLoader over torchvision's StanfordCars dataset."""
        # Downloading is disabled, so the files must already be present under the root folder;
        # see https://github.com/pytorch/vision/issues/7545#issuecomment-1631441616
        cars = datasets.StanfordCars(root=".", transform=self.transform, download=False)
        loader_options = dict(batch_size=256, shuffle=True, num_workers=5)
        return torch.utils.data.DataLoader(cars, **loader_options)
if __name__ == "__main__":
    # Script entry point: build the module and the data, then run the training loop.
    classifier = LitClassification()
    cars_data = ClassificationData()
    trainer_options = {"max_epochs": 50, "precision": "16", "log_every_n_steps": 5}
    runner = pl.Trainer(**trainer_options)
    runner.fit(classifier, cars_data)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment