Skip to content

Instantly share code, notes, and snippets.

@hkristen
Created January 4, 2026 09:01
Show Gist options
  • Select an option

  • Save hkristen/ee0b3f0e2324c9d606da3a7e537a2582 to your computer and use it in GitHub Desktop.

Select an option

Save hkristen/ee0b3f0e2324c9d606da3a7e537a2582 to your computer and use it in GitHub Desktop.
TiledInferenceCallback example for TorchGeo PR #3214
"""Example: TiledInferenceCallback for large-scale geospatial inference.
This script demonstrates how to use TorchGeo's TiledInferenceCallback to run
inference on large rasters that don't fit in memory. It handles:
- Tiled processing with configurable patch size and overlap
- Weighted blending at patch boundaries (cosine/linear)
- Automatic geospatial metadata preservation
- Chunked GeoTIFF output for large results
Usage:
python tiled_inference_example.py \\
--raster /path/to/large_raster.tif \\
--checkpoint model.ckpt \\
--output prediction.tif \\
--patch-size 256 \\
--overlap 64
Optional:
--roi /path/to/roi.geojson # Restrict processing to the ROI's bounding box
--delta 32 # Discard N pixels from patch edges
--batch-size 16 # Patches per batch
--cpu # Force CPU inference
"""
import argparse
import sys
from pathlib import Path
import torch
from lightning import Trainer
# Add current directory to path for development
sys.path.insert(0, str(Path(__file__).resolve().parent))
from torchgeo.callbacks import TiledInferenceCallback
from torchgeo.datamodules import GeoDataModule
from torchgeo.datasets import RasterDataset
from torchgeo.trainers import SemanticSegmentationTask
def _stride(patch_size, overlap):
    """Return the sampler stride for a patch size and per-side overlap."""
    # NOTE(review): ``overlap`` is subtracted from both sides, so patches placed
    # every ``patch_size - 2 * overlap`` pixels share ``2 * overlap`` pixels with
    # their neighbour. Confirm this matches how TiledInferenceCallback
    # interprets its ``overlap`` argument.
    return patch_size - 2 * overlap


def _parse_args(argv=None):
    """Parse and validate the command-line arguments.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``. Invalid values terminate the process
        through ``parser.error`` (exit status 2) before any model or dataset
        is loaded.
    """
    parser = argparse.ArgumentParser(
        description="Run tiled inference on large geospatial rasters"
    )
    parser.add_argument("--raster", type=str, required=True, help="Input raster path")
    parser.add_argument("--checkpoint", type=str, required=True, help="Model checkpoint path")
    parser.add_argument("--output", type=str, required=True, help="Output GeoTIFF path")
    parser.add_argument("--patch-size", type=int, default=256, help="Patch size (default: 256)")
    parser.add_argument("--overlap", type=int, default=64, help="Overlap in pixels (default: 64)")
    parser.add_argument("--delta", type=int, default=32, help="Edge pixels to discard (default: 32)")
    parser.add_argument("--batch-size", type=int, default=16, help="Batch size (default: 16)")
    parser.add_argument(
        "--blend-method",
        choices=("cosine", "linear"),
        default="cosine",
        help="Blending weights at patch boundaries (default: cosine)",
    )
    parser.add_argument("--roi", type=str, help="Optional ROI GeoJSON/Shapefile")
    parser.add_argument("--cpu", action="store_true", help="Force CPU inference")
    args = parser.parse_args(argv)

    # Fail fast on values that would make the sliding window degenerate,
    # before the (slow) checkpoint and raster index are loaded.
    if args.patch_size <= 0:
        parser.error("--patch-size must be positive")
    if args.batch_size <= 0:
        parser.error("--batch-size must be positive")
    if args.overlap < 0:
        parser.error("--overlap must be non-negative")
    if args.delta < 0:
        parser.error("--delta must be non-negative")
    if 2 * args.delta >= args.patch_size:
        parser.error("--delta must be less than half of --patch-size")
    if _stride(args.patch_size, args.overlap) <= 0:
        parser.error(
            f"--overlap ({args.overlap}) must be less than half of "
            f"--patch-size ({args.patch_size}) so the sampler stride is positive"
        )
    return args


def _load_roi(roi_path, raster_path):
    """Return the bounding box of all ROI features, expressed in the raster's CRS.

    Args:
        roi_path: Vector file readable by ``geopandas.read_file``.
        raster_path: Raster whose CRS the ROI is reprojected into.

    Returns:
        A shapely box covering the total bounds of every ROI feature.

    Raises:
        SystemExit: If the ROI file contains no features.
    """
    import geopandas as gpd
    from shapely.geometry import box as shapely_box

    gdf = gpd.read_file(roi_path)
    if gdf.empty:
        sys.exit(f"ROI file contains no features: {roi_path}")
    # total_bounds is expressed in the vector file's own CRS (RFC 7946 GeoJSON
    # is WGS84), while the sampler works in the raster's CRS; the shapely box
    # carries no CRS, so the reprojection has to happen here.
    if gdf.crs is None:
        print("⚠ ROI has no CRS; assuming it matches the raster CRS")
    else:
        raster_crs = RasterDataset(paths=raster_path).crs
        gdf = gdf.to_crs(raster_crs)
    return shapely_box(*gdf.total_bounds)


def main(argv=None):
    """Run tiled inference over a raster and write the prediction as a GeoTIFF.

    Args:
        argv: Optional argument list (defaults to ``sys.argv[1:]``), exposed so
            the script can also be driven programmatically.
    """
    args = _parse_args(argv)

    print("=" * 80)
    print("TorchGeo TiledInferenceCallback Example")
    print("=" * 80)
    print(f"Raster: {args.raster}")
    print(f"Checkpoint: {args.checkpoint}")
    print(f"Output: {args.output}")
    print(f"Patch size: {args.patch_size}")
    print(f"Overlap: {args.overlap}")
    print(f"Delta: {args.delta}")
    print(f"Batch size: {args.batch_size}")
    print(f"Blend method: {args.blend_method}")
    if args.roi:
        print(f"ROI: {args.roi}")
    print("=" * 80)

    # Device used only for deserializing the checkpoint; the Trainer below
    # chooses its own accelerator.
    device = "cpu" if args.cpu else ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nDevice: {device}")

    # Load model
    print("\nLoading model...")
    task = SemanticSegmentationTask.load_from_checkpoint(
        args.checkpoint,
        map_location=device,
    )
    task.eval()
    print("✓ Model loaded")

    # Load ROI (reprojected into the raster CRS) if provided
    roi = None
    if args.roi:
        roi = _load_roi(args.roi, args.raster)
        print(f"✓ ROI loaded: {roi.bounds}")

    # Create datamodule
    print("\nCreating datamodule...")
    stride = _stride(args.patch_size, args.overlap)
    datamodule = GeoDataModule(
        dataset_class=RasterDataset,
        batch_size=args.batch_size,
        patch_size=args.patch_size,
        length=None,
        num_workers=0,
        predict_roi=roi,
        predict_stride=stride,
        paths=args.raster,
    )
    datamodule.setup(stage='predict')
    print(f"✓ Dataset CRS: {datamodule.predict_dataset.crs}")
    print(f"✓ Dataset bounds: {datamodule.predict_dataset.bounds}")
    print(f"✓ Number of patches: {len(datamodule.predict_sampler)}")

    # Create callback
    print("\nSetting up TiledInferenceCallback...")
    callback = TiledInferenceCallback(
        output_path=args.output,
        overlap=args.overlap,
        delta=args.delta,
        blend_method=args.blend_method,
        chunk_size=4096,
    )
    print("✓ Callback initialized")

    # Run inference
    print("\n" + "=" * 80)
    print("RUNNING INFERENCE")
    print("=" * 80)
    trainer = Trainer(
        callbacks=[callback],
        accelerator='cpu' if args.cpu else 'auto',
        devices=1,
        logger=False,
        enable_checkpointing=False,
        enable_progress_bar=True,
    )
    # Make sure the output directory exists; whether the callback creates it
    # itself is not visible from this script.
    Path(args.output).parent.mkdir(parents=True, exist_ok=True)
    print("\nStarting prediction...")
    trainer.predict(task, datamodule=datamodule)
    print("\n" + "=" * 80)
    print("✓ INFERENCE COMPLETE!")
    print("=" * 80)
    print(f"✓ Output saved to: {args.output}")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment