Created
January 4, 2026 09:01
-
-
Save hkristen/ee0b3f0e2324c9d606da3a7e537a2582 to your computer and use it in GitHub Desktop.
TiledInferenceCallback example for TorchGeo PR #3214
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """Example: TiledInferenceCallback for large-scale geospatial inference. | |
| This script demonstrates how to use TorchGeo's TiledInferenceCallback to run | |
| inference on large rasters that don't fit in memory. It handles: | |
| - Tiled processing with configurable patch size and overlap | |
| - Weighted blending at patch boundaries (cosine/linear) | |
| - Automatic geospatial metadata preservation | |
| - Chunked GeoTIFF output for large results | |
| Usage: | |
| python tiled_inference_example.py \\ | |
| --raster /path/to/large_raster.tif \\ | |
| --checkpoint model.ckpt \\ | |
| --output prediction.tif \\ | |
| --patch-size 256 \\ | |
| --overlap 64 | |
| Optional: | |
| --roi /path/to/roi.geojson # Process only within ROI bounds | |
| --delta 32 # Discard N pixels from patch edges | |
| --batch-size 16 # Patches per batch | |
| --cpu # Force CPU inference | |
| """ | |
import argparse
import sys
from pathlib import Path
import torch
from lightning import Trainer
# Add current directory to path for development.
# NOTE(review): prepending the script's own directory lets a locally
# checked-out torchgeo (the PR #3214 branch) shadow any installed copy —
# presumably temporary; remove once TiledInferenceCallback ships in a release.
sys.path.insert(0, str(Path(__file__).resolve().parent))
from torchgeo.callbacks import TiledInferenceCallback
from torchgeo.datamodules import GeoDataModule
from torchgeo.datasets import RasterDataset
from torchgeo.trainers import SemanticSegmentationTask
def _parse_args() -> argparse.Namespace:
    """Parse command-line arguments for the tiled-inference example.

    Returns:
        Parsed namespace with raster/checkpoint/output paths, tiling
        parameters (patch_size, overlap, delta), batch size, optional ROI
        path, and the CPU-only flag.
    """
    parser = argparse.ArgumentParser(
        description="Run tiled inference on large geospatial rasters"
    )
    parser.add_argument("--raster", type=str, required=True, help="Input raster path")
    parser.add_argument("--checkpoint", type=str, required=True, help="Model checkpoint path")
    parser.add_argument("--output", type=str, required=True, help="Output GeoTIFF path")
    parser.add_argument("--patch-size", type=int, default=256, help="Patch size (default: 256)")
    parser.add_argument("--overlap", type=int, default=64, help="Overlap in pixels (default: 64)")
    parser.add_argument("--delta", type=int, default=32, help="Edge pixels to discard (default: 32)")
    parser.add_argument("--batch-size", type=int, default=16, help="Batch size (default: 16)")
    parser.add_argument("--roi", type=str, help="Optional ROI GeoJSON/Shapefile")
    parser.add_argument("--cpu", action="store_true", help="Force CPU inference")
    return parser.parse_args()


def _load_roi(path: str):
    """Load a vector file and return its *bounding box* as a shapely geometry.

    NOTE: only ``total_bounds`` of the file is used, not the exact geometry —
    inference will cover the full bounding rectangle of all features.

    Args:
        path: Path to a GeoJSON/Shapefile readable by geopandas.

    Returns:
        A shapely box covering the file's total bounds.
    """
    # Imported lazily so geopandas/shapely are only required when --roi is used.
    import geopandas as gpd
    from shapely.geometry import box as shapely_box

    gdf = gpd.read_file(path)
    return shapely_box(*gdf.total_bounds)


def main() -> None:
    """Run tiled semantic-segmentation inference over a large raster.

    Loads a SemanticSegmentationTask checkpoint, builds a GeoDataModule over
    the input raster (optionally restricted to an ROI bounding box), and runs
    Lightning ``Trainer.predict`` with a TiledInferenceCallback that blends
    overlapping patches and writes a chunked GeoTIFF.

    Raises:
        ValueError: if ``--overlap`` is at least half of ``--patch-size``,
            which would make the sampler stride non-positive.
    """
    args = _parse_args()

    # --- Banner -------------------------------------------------------------
    print("=" * 80)
    print("TorchGeo TiledInferenceCallback Example")
    print("=" * 80)
    print(f"Raster: {args.raster}")
    print(f"Checkpoint: {args.checkpoint}")
    print(f"Output: {args.output}")
    print(f"Patch size: {args.patch_size}")
    print(f"Overlap: {args.overlap}")
    print(f"Delta: {args.delta}")
    print(f"Batch size: {args.batch_size}")
    if args.roi:
        print(f"ROI: {args.roi}")
    print("=" * 80)

    # Device: honor --cpu, otherwise prefer CUDA when available.
    device = "cpu" if args.cpu else ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nDevice: {device}")

    # --- Model --------------------------------------------------------------
    print("\nLoading model...")
    task = SemanticSegmentationTask.load_from_checkpoint(
        args.checkpoint,
        map_location=device
    )
    task.eval()
    print("✓ Model loaded")

    # --- Optional ROI -------------------------------------------------------
    roi = None
    if args.roi:
        roi = _load_roi(args.roi)
        print(f"✓ ROI loaded: {roi.bounds}")

    # --- Datamodule ---------------------------------------------------------
    print("\nCreating datamodule...")
    # Stride leaves `overlap` pixels shared on each side of adjacent patches.
    # NOTE(review): assumes the sampler treats overlap as per-side — confirm
    # against TiledInferenceCallback's expected stride convention.
    stride = args.patch_size - 2 * args.overlap
    if stride <= 0:
        # Guard against a zero/negative stride, which would make the sampler
        # loop forever or fail with an opaque error downstream.
        raise ValueError(
            f"--overlap ({args.overlap}) must be less than half of "
            f"--patch-size ({args.patch_size}); got stride {stride}"
        )
    datamodule = GeoDataModule(
        dataset_class=RasterDataset,
        batch_size=args.batch_size,
        patch_size=args.patch_size,
        length=None,
        num_workers=0,
        predict_roi=roi,
        predict_stride=stride,
        paths=args.raster
    )
    datamodule.setup(stage='predict')
    print(f"✓ Dataset CRS: {datamodule.predict_dataset.crs}")
    print(f"✓ Dataset bounds: {datamodule.predict_dataset.bounds}")
    print(f"✓ Number of patches: {len(datamodule.predict_sampler)}")

    # --- Callback -----------------------------------------------------------
    print("\nSetting up TiledInferenceCallback...")
    callback = TiledInferenceCallback(
        output_path=args.output,
        overlap=args.overlap,
        delta=args.delta,
        blend_method='cosine',
        chunk_size=4096
    )
    print("✓ Callback initialized")

    # --- Inference ----------------------------------------------------------
    print("\n" + "=" * 80)
    print("RUNNING INFERENCE")
    print("=" * 80)
    trainer = Trainer(
        callbacks=[callback],
        accelerator='cpu' if args.cpu else 'auto',
        devices=1,
        logger=False,
        enable_checkpointing=False,
        enable_progress_bar=True
    )
    print("\nStarting prediction...")
    trainer.predict(task, datamodule=datamodule)

    print("\n" + "=" * 80)
    print("✓ INFERENCE COMPLETE!")
    print("=" * 80)
    print(f"✓ Output saved to: {args.output}")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment