Skip to content

Instantly share code, notes, and snippets.

@isaaccorley
Last active September 23, 2022 20:36
Show Gist options
  • Select an option

  • Save isaaccorley/17f9ce96eb075550800d0c27bc9d8fe1 to your computer and use it in GitHub Desktop.

Select an option

Save isaaccorley/17f9ce96eb075550800d0c27bc9d8fe1 to your computer and use it in GitHub Desktop.
TorchGeo UCMerced Classifier Example
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"private_outputs": true,
"provenance": [],
"authorship_tag": "ABX9TyP+l8X4ZEIACc+Q7yDUHCGY",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/isaaccorley/17f9ce96eb075550800d0c27bc9d8fe1/torchgeo_ucmerced_classifier.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_eePWYKS1zIj"
},
"outputs": [],
"source": [
"# Pin the torchgeo release whose API this notebook is written against\n",
"# (UCMercedDataModule(root_dir=...), ClassificationTask(classification_model=..., learning_rate=...));\n",
"# newer torchgeo releases changed these constructor arguments.\n",
"# %pip (not !pip) installs into the environment of the running kernel.\n",
"%pip install -q torchgeo==0.3.1"
]
},
{
"cell_type": "code",
"source": [
"%load_ext tensorboard\n",
"\n",
"import os\n",
"\n",
"import pytorch_lightning as pl\n",
"\n",
"from torchgeo.datasets import UCMerced\n",
"from torchgeo.datamodules import UCMercedDataModule\n",
"from torchgeo.trainers import ClassificationTask"
],
"metadata": {
"id": "0_RVe5Dp3qjV"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Parameters\n",
"data_dir = \"./data\"\n",
"num_classes = 21\n",
"channels = 3\n",
"batch_size = 4\n",
"num_workers = 2\n",
"backbone = \"resnet18\"\n",
"weights = \"imagenet\"\n",
"lr = 0.01\n",
"lr_schedule_patience = 5\n",
"epochs = 50\n",
"seed = 0\n",
"\n",
"# Seed the Python, NumPy and PyTorch RNGs (and dataloader worker processes)\n",
"# so weight initialisation and batch shuffling are reproducible across runs.\n",
"pl.seed_everything(seed, workers=True)"
],
"metadata": {
"id": "Vucgiszi3aZ3"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Fetch the UC Merced imagery and its dataset split files into data_dir,\n",
"# verifying the download checksum. The returned object is not used again;\n",
"# its purpose here is to populate data_dir on disk.\n",
"dataset = UCMerced(\n",
"    data_dir,\n",
"    download=True,\n",
"    checksum=True,\n",
")"
],
"metadata": {
"id": "fSFCJwfO5ClW"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Datamodule: supplies the dataloaders used by trainer.fit / trainer.test,\n",
"# reading the UC Merced data stored under data_dir.\n",
"datamodule = UCMercedDataModule(\n",
"    root_dir=data_dir,\n",
"    batch_size=batch_size,\n",
"    num_workers=num_workers,\n",
")"
],
"metadata": {
"id": "CWfhPtsa4Xgg"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Classification task (the LightningModule passed to the trainer),\n",
"# configured from the parameters cell.\n",
"task = ClassificationTask(\n",
"    # network\n",
"    classification_model=backbone,\n",
"    weights=weights,\n",
"    in_channels=channels,\n",
"    num_classes=num_classes,\n",
"    # objective and optimisation\n",
"    loss=\"ce\",\n",
"    learning_rate=lr,\n",
"    learning_rate_schedule_patience=lr_schedule_patience,\n",
")"
],
"metadata": {
"id": "Voqgk8Qd3aX6"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Both callbacks watch validation loss, where lower is better (mode=\"min\").\n",
"# Keep the single best checkpoint as well as the most recent one.\n",
"checkpoint_callback = pl.callbacks.ModelCheckpoint(\n",
"    monitor=\"val_loss\",\n",
"    mode=\"min\",\n",
"    save_top_k=1,\n",
"    save_last=True,\n",
")\n",
"\n",
"# Stop training after 10 consecutive validation checks without any improvement.\n",
"early_stopping_callback = pl.callbacks.EarlyStopping(\n",
"    monitor=\"val_loss\",\n",
"    mode=\"min\",\n",
"    patience=10,\n",
"    min_delta=0.0,\n",
")"
],
"metadata": {
"id": "M3GUmmAl3nrQ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Configure the trainer; training itself is started by trainer.fit.\n",
"# accelerator=\"auto\" selects a GPU when one is available and falls back to CPU\n",
"# otherwise (the old `gpus=` argument is deprecated in Lightning 1.7 and removed in 2.0).\n",
"trainer = pl.Trainer(\n",
"    callbacks=[checkpoint_callback, early_stopping_callback],\n",
"    max_epochs=epochs,\n",
"    accelerator=\"auto\",\n",
"    devices=1,\n",
")"
],
"metadata": {
"id": "cIzlh92t3npU"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Train the task; validation runs feed val_loss to the checkpoint and early-stopping callbacks.\n",
"trainer.fit(task, datamodule=datamodule)"
],
"metadata": {
"id": "Az9p03sI3nnm"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Test the best checkpoint (lowest val_loss), not the final in-memory weights:\n",
"# with early stopping the last epoch can be several epochs past the best one.\n",
"test_metrics = trainer.test(model=task, datamodule=datamodule, ckpt_path=\"best\")"
],
"metadata": {
"id": "1stLC7i88KFj"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Browse the training/validation curves written to lightning_logs during the run.\n",
"%tensorboard --logdir=lightning_logs"
],
"metadata": {
"id": "yqKO8H-z3nla"
},
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment