Skip to content

Instantly share code, notes, and snippets.

@ashnair1
Created September 21, 2022 12:28
Show Gist options
  • Save ashnair1/369842c3f9833a3a49a54c8c284d2013 to your computer and use it in GitHub Desktop.
Save ashnair1/369842c3f9833a3a49a54c8c284d2013 to your computer and use it in GitHub Desktop.
torchgeo predict using Caleb's SingleRasterDataset
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""torchgeo model inference script."""
import argparse
import os
from typing import Any, Callable, Dict, Optional, Tuple, Type, cast
import pytorch_lightning as pl
import rasterio as rio
import rasterio.errors as rioerr
import torch
import numpy as np
from torch.utils.data import DataLoader
from kornia.contrib import CombineTensorPatches, compute_padding
from omegaconf import OmegaConf
from torchgeo.datasets import RasterDataset, stack_samples
from torchgeo.datasets.geo import NonGeoDataset
from torchgeo.samplers import GridGeoSampler
from torchgeo.datamodules import (
BigEarthNetDataModule,
ChesapeakeCVPRDataModule,
COWCCountingDataModule,
CycloneDataModule,
ETCI2021DataModule,
EuroSATDataModule,
InriaAerialImageLabelingDataModule,
LandCoverAIDataModule,
NAIPChesapeakeDataModule,
OSCDDataModule,
RESISC45DataModule,
SEN12MSDataModule,
So2SatDataModule,
UCMercedDataModule,
)
from torchgeo.trainers import (
BYOLTask,
ClassificationTask,
MultiLabelClassificationTask,
RegressionTask,
SemanticSegmentationTask,
)
# Lookup table keyed by the ``experiment.task`` value of an experiment config.
# Each entry pairs the LightningModule class used to restore the checkpoint
# with the LightningDataModule class whose ``preprocess`` is reused at
# inference time (see ``main``).
TASK_TO_MODULES_MAPPING: Dict[
    str, Tuple[Type[pl.LightningModule], Type[pl.LightningDataModule]]
] = {
    "bigearthnet": (MultiLabelClassificationTask, BigEarthNetDataModule),
    "byol": (BYOLTask, ChesapeakeCVPRDataModule),
    "chesapeake_cvpr": (SemanticSegmentationTask, ChesapeakeCVPRDataModule),
    "cowc_counting": (RegressionTask, COWCCountingDataModule),
    "cyclone": (RegressionTask, CycloneDataModule),
    "eurosat": (ClassificationTask, EuroSATDataModule),
    "etci2021": (SemanticSegmentationTask, ETCI2021DataModule),
    "inria": (SemanticSegmentationTask, InriaAerialImageLabelingDataModule),
    "landcoverai": (SemanticSegmentationTask, LandCoverAIDataModule),
    "naipchesapeake": (SemanticSegmentationTask, NAIPChesapeakeDataModule),
    "oscd": (SemanticSegmentationTask, OSCDDataModule),
    "resisc45": (ClassificationTask, RESISC45DataModule),
    "sen12ms": (SemanticSegmentationTask, SEN12MSDataModule),
    "so2sat": (ClassificationTask, So2SatDataModule),
    "ucmerced": (ClassificationTask, UCMercedDataModule),
}
class SingleRasterDataset(RasterDataset):
    """A ``RasterDataset`` restricted to one raster file.

    The dataset root is the directory containing the file and the file name
    is used as ``filename_regex`` so that only that file is selected.  The
    raster is declared as having four bands, of which the first three are
    requested from the parent class.
    """

    # Band layout declared for the input rasters: "1".."4".
    num_bands = 4
    all_bands = [str(i) for i in range(1, num_bands + 1)]

    def __init__(
        self,
        fn: str,
        transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
    ) -> None:
        """Build a dataset that indexes only the raster at ``fn``.

        Args:
            fn: path of the raster file to wrap.
            transforms: optional callable forwarded unchanged to
                ``RasterDataset`` as ``transforms``.
        """
        # Local import keeps this class self-contained; ``re`` is not
        # imported at the top of the file.
        import re

        # ``filename_regex`` is a regular expression, so the literal file name
        # must be escaped: otherwise "." matches any character and names
        # containing "+", "(", "[" or spaces either match the wrong files or
        # fail to match at all.  The trailing "$" stops sidecar files such as
        # "<name>.ovr" or "<name>.aux.xml" from matching by prefix.
        self.filename_regex = re.escape(os.path.basename(fn)) + "$"
        bands = ["1", "2", "3"]
        super().__init__(
            root=os.path.dirname(fn), bands=bands, transforms=transforms
        )
def _find_rasters(predict_on):
    """Return ``(path, profile)`` for every file under ``predict_on`` that rasterio can open."""
    rasters = []
    for root, _dirs, names in os.walk(predict_on):
        for name in names:
            path = os.path.join(root, name)
            try:
                with rio.open(path) as src:
                    rasters.append((path, src.profile.copy()))
            except rioerr.RasterioIOError:
                # Skip files that rasterio is unable to read
                continue
    return rasters


def _paste_patch(
    predictions: np.ndarray, patch: np.ndarray, top: int, left: int, padding: int
) -> None:
    """Copy the interior of ``patch`` (margin ``padding`` removed) into ``predictions``.

    ``top``/``left`` are the pixel row/column of the patch's upper-left corner
    in ``predictions``.  The destination window is clipped to the array, so a
    patch that overhangs the raster is written partially instead of failing
    with a shape mismatch, and a negative offset cannot wrap around.
    """
    height, width = predictions.shape
    patch_h, patch_w = patch.shape

    row_start = max(top + padding, 0)
    row_stop = min(top + patch_h - padding, height)
    col_start = max(left + padding, 0)
    col_stop = min(left + patch_w - padding, width)
    if row_start >= row_stop or col_start >= col_stop:
        # The interior of this patch lies entirely outside the raster.
        return

    predictions[row_start:row_stop, col_start:col_stop] = patch[
        row_start - top : row_stop - top, col_start - left : col_stop - left
    ]


def main(config_dir: str, predict_on: str, output_dir: str, device: str) -> None:
    """Run a trained task over every readable raster under ``predict_on``.

    The experiment config (``experiment_config.yaml``) and checkpoint
    (``last.ckpt``) are read from ``config_dir``.  Each raster is tiled into
    overlapping 512-pixel patches; the per-pixel argmax of the model output is
    kept for the central part of each patch (a 96-pixel margin is discarded)
    and the assembled mask is written to ``output_dir`` as a single-band,
    LZW-compressed uint8 raster named after the input file.  Pixels that no
    patch interior covers, including the outer 96-pixel frame of the image,
    stay 0.

    Args:
        config_dir: directory holding ``experiment_config.yaml`` and ``last.ckpt``.
        predict_on: directory that is walked recursively for input rasters.
        output_dir: directory the prediction rasters are written to; created
            if missing.
        device: device string passed to ``.to()`` for the model and inputs.

    Raises:
        ValueError: if ``experiment.task`` is not a key of
            ``TASK_TO_MODULES_MAPPING``, or if a sampled patch does not span
            the expected number of pixels.
        FileExistsError: if ``output_dir`` is not empty and
            ``program.overwrite`` is false in the config.
    """
    # Tiling parameters; the stride makes the kept interiors of neighbouring
    # patches tile the image without gaps.
    patch_size = 512
    padding = 96
    stride = patch_size - 2 * padding

    os.makedirs(output_dir, exist_ok=True)

    # Load checkpoint and config
    conf = OmegaConf.load(os.path.join(config_dir, "experiment_config.yaml"))
    ckpt = os.path.join(config_dir, "last.ckpt")

    # Load model
    task_name = conf.experiment.task
    task: pl.LightningModule
    if task_name not in TASK_TO_MODULES_MAPPING:
        raise ValueError(
            f"experiment.task={task_name} is not recognized as a valid task"
        )
    task_class, datamodule_class = TASK_TO_MODULES_MAPPING[task_name]
    task = task_class.load_from_checkpoint(ckpt)
    task = task.to(device)
    task.eval()

    # Load preprocess tfm
    # NOTE(review): ``preprocess`` is looked up on the class, not on an
    # instance; if a datamodule defines it as a regular instance method this
    # needs an instance instead -- confirm against the installed torchgeo.
    preprocess = datamodule_class.preprocess

    if os.listdir(output_dir):
        if conf.program.overwrite:
            print(
                f"WARNING! The output directory, {output_dir}, already exists, "
                + "we will overwrite data in it!"
            )
        else:
            raise FileExistsError(
                f"The predictions directory, {output_dir}, already exists and isn't "
                + "empty. We don't want to overwrite any existing results, exiting..."
            )

    for img, imgprofile in _find_rasters(predict_on):
        ds = SingleRasterDataset(img, transforms=preprocess)
        h, w = imgprofile["height"], imgprofile["width"]
        tfm = imgprofile["transform"]

        sampler = GridGeoSampler(ds, size=patch_size, stride=stride)
        dl = DataLoader(
            ds,
            sampler=sampler,
            batch_size=1,
            num_workers=0,
            collate_fn=stack_samples,
        )

        predictions = np.zeros((h, w), dtype=np.uint8)
        for batch in dl:
            images = batch["image"].to(device)
            bboxes = batch["bbox"]
            with torch.inference_mode():
                outputs = task(images)
            # assumes the task output is (batch, classes, H, W) -- TODO confirm
            outputs = outputs.argmax(dim=1).cpu().numpy()

            for bb, patch in zip(bboxes, outputs):
                # Map the patch's geographic corners to pixel coordinates with
                # the inverse of the raster's affine transform.
                left, top = ~tfm * (bb.minx, bb.maxy)
                right, bottom = ~tfm * (bb.maxx, bb.miny)
                left, right = round(left), round(right)
                top, bottom = round(top), round(bottom)
                # Explicit check instead of ``assert`` so it is not stripped
                # when Python runs with -O.
                if right - left != patch_size or bottom - top != patch_size:
                    raise ValueError(
                        f"Patch {bb} of {img} spans {right - left}x{bottom - top} "
                        f"pixels, expected {patch_size}x{patch_size}"
                    )
                _paste_patch(predictions, patch, top, left, padding)

        # Work on a copy so the profile collected for the input is not mutated.
        profile = imgprofile.copy()
        profile.update(count=1, dtype="uint8", compress="lzw", predictor=2)

        # TODO: how to handle output filename with directories?
        outpath = os.path.join(output_dir, os.path.basename(img))
        with rio.open(outpath, "w", **profile) as f:
            f.write(predictions, 1)
if __name__ == "__main__":
    # GDAL/rasterio environment settings, applied before any raster is opened.
    # Taken from https://github.com/pangeo-data/cog-best-practices
    _rasterio_best_practices = {
        "GDAL_DISABLE_READDIR_ON_OPEN": "EMPTY_DIR",
        "AWS_NO_SIGN_REQUEST": "YES",
        "GDAL_MAX_RAW_BLOCK_CACHE_SIZE": "200000000",
        "GDAL_SWATH_SIZE": "200000000",
        "VSI_CURL_CACHE_SIZE": "200000000",
    }
    os.environ.update(_rasterio_best_practices)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config-dir",
        type=str,
        required=True,
        help="Path to config-dir to load config and ckpt",
    )
    # ``--predict-on`` matches the hyphenated style of the other options;
    # the original ``--predict_on`` spelling is kept as an alias so existing
    # command lines keep working.  Both store into ``args.predict_on``.
    parser.add_argument(
        "--predict-on",
        "--predict_on",
        dest="predict_on",
        type=str,
        required=True,
        help="Directory/Dataset to run inference on",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="Path to output_directory to save predicted mask geotiffs",
    )
    parser.add_argument(
        "--device", type=str, default="cuda", choices=["cuda", "cpu"]
    )

    args = parser.parse_args()
    main(args.config_dir, args.predict_on, args.output_dir, args.device)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment