Created September 21, 2022 12:28
-
-
Save ashnair1/369842c3f9833a3a49a54c8c284d2013 to your computer and use it in GitHub Desktop.
torchgeo predict using Caleb's SingleRasterDataset
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# Copyright (c) Microsoft Corporation. All rights reserved. | |
# Licensed under the MIT License. | |
"""torchgeo model inference script.""" | |
import argparse
import os
import re
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, cast

import numpy as np
import pytorch_lightning as pl
import rasterio as rio
import rasterio.errors as rioerr
import torch
from kornia.contrib import CombineTensorPatches, compute_padding
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from torchgeo.datamodules import (
    BigEarthNetDataModule,
    ChesapeakeCVPRDataModule,
    COWCCountingDataModule,
    CycloneDataModule,
    ETCI2021DataModule,
    EuroSATDataModule,
    InriaAerialImageLabelingDataModule,
    LandCoverAIDataModule,
    NAIPChesapeakeDataModule,
    OSCDDataModule,
    RESISC45DataModule,
    SEN12MSDataModule,
    So2SatDataModule,
    UCMercedDataModule,
)
from torchgeo.datasets import RasterDataset, stack_samples
from torchgeo.datasets.geo import NonGeoDataset
from torchgeo.samplers import GridGeoSampler
from torchgeo.trainers import (
    BYOLTask,
    ClassificationTask,
    MultiLabelClassificationTask,
    RegressionTask,
    SemanticSegmentationTask,
)
# Registry of supported ``experiment.task`` config values. Each entry pairs the
# LightningModule class whose ``load_from_checkpoint`` restores the model with
# the LightningDataModule class whose ``preprocess`` is used as the dataset
# transform in ``main``.
# NOTE(review): main() writes a per-pixel class map (argmax over axis 1 into a
# 2-D array), so presumably only the SemanticSegmentationTask entries are
# usable with this script — confirm before relying on the others.
TASK_TO_MODULES_MAPPING: Dict[
    str, Tuple[Type[pl.LightningModule], Type[pl.LightningDataModule]]
] = dict(
    bigearthnet=(MultiLabelClassificationTask, BigEarthNetDataModule),
    byol=(BYOLTask, ChesapeakeCVPRDataModule),
    chesapeake_cvpr=(SemanticSegmentationTask, ChesapeakeCVPRDataModule),
    cowc_counting=(RegressionTask, COWCCountingDataModule),
    cyclone=(RegressionTask, CycloneDataModule),
    eurosat=(ClassificationTask, EuroSATDataModule),
    etci2021=(SemanticSegmentationTask, ETCI2021DataModule),
    inria=(SemanticSegmentationTask, InriaAerialImageLabelingDataModule),
    landcoverai=(SemanticSegmentationTask, LandCoverAIDataModule),
    naipchesapeake=(SemanticSegmentationTask, NAIPChesapeakeDataModule),
    oscd=(SemanticSegmentationTask, OSCDDataModule),
    resisc45=(ClassificationTask, RESISC45DataModule),
    sen12ms=(SemanticSegmentationTask, SEN12MSDataModule),
    so2sat=(ClassificationTask, So2SatDataModule),
    ucmerced=(ClassificationTask, UCMercedDataModule),
)
class SingleRasterDataset(RasterDataset):
    """A ``RasterDataset`` restricted to one raster file.

    The file is selected by setting ``filename_regex`` to a pattern that
    matches exactly the file's basename, with ``root`` set to the file's
    directory.
    """

    # Band names exposed by this dataset ("1".."4").
    # NOTE(review): assumes the input rasters have (at least) 4 bands — confirm.
    num_bands = 4
    all_bands = [str(i) for i in range(1, num_bands + 1)]

    def __init__(
        self,
        fn: str,
        transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
        bands: Optional[Sequence[str]] = None,
    ) -> None:
        """Initialize a dataset over the single raster ``fn``.

        Args:
            fn: path to the raster file to index.
            transforms: optional sample transform, forwarded to ``RasterDataset``.
            bands: names (from ``all_bands``) of the bands to read; defaults to
                ``["1", "2", "3"]``.
        """
        # Escape the basename so characters such as "." or "(" match literally
        # (re.escape also escapes whitespace and "#", which matters if the
        # parent compiles the pattern with re.VERBOSE), and anchor both ends so
        # e.g. "x.tif" does not also select "x.tif.aux.xml".
        # NOTE(review): if the parent searches ``root`` recursively, a file with
        # the same basename in a subdirectory would still be indexed — confirm.
        self.filename_regex = "^" + re.escape(os.path.basename(fn)) + "$"
        if bands is None:
            bands = ["1", "2", "3"]
        super().__init__(
            root=os.path.dirname(fn), bands=list(bands), transforms=transforms
        )
def _find_rasters(predict_on: str) -> Dict[str, Dict[str, Any]]:
    """Collect every file under ``predict_on`` that rasterio can open.

    Returns:
        Mapping from file path to a copy of its rasterio profile, in
        ``os.walk`` order.
    """
    rasters: Dict[str, Dict[str, Any]] = {}
    for root, _dirs, fns in os.walk(predict_on):
        for fn in fns:
            path = os.path.join(root, fn)
            try:
                with rio.open(path) as src:
                    rasters[path] = src.profile.copy()
            except rioerr.RasterioIOError:
                # Skip files that rasterio is unable to read
                continue
    return rasters


def _predict_raster(
    task: pl.LightningModule,
    img: str,
    profile: Dict[str, Any],
    preprocess: Any,
    device: str,
    patch_size: int,
    padding: int,
) -> np.ndarray:
    """Predict a per-pixel class map for one raster using overlapping tiles.

    Tiles of ``patch_size`` pixels are sampled every ``patch_size - 2 * padding``
    pixels. Only the centre of each tile's prediction is kept, except on sides
    that touch the raster border. There the margin is kept too, so the border
    is not left at 0.

    Returns:
        A ``(height, width)`` uint8 array of argmax class indices.

    Raises:
        RuntimeError: if a sampled tile does not map to a
            ``patch_size`` x ``patch_size`` pixel window.
    """
    ds = SingleRasterDataset(img, transforms=preprocess)
    h, w = profile["height"], profile["width"]
    # Inverse affine transform: world (x, y) -> fractional (col, row).
    inv_tfm = ~profile["transform"]
    sampler = GridGeoSampler(ds, size=patch_size, stride=patch_size - 2 * padding)
    dl = DataLoader(
        ds, sampler=sampler, batch_size=1, num_workers=0, collate_fn=stack_samples
    )

    predictions = np.zeros((h, w), dtype=np.uint8)
    # Tracks which pixels received a prediction, so gaps left by the sampler
    # are reported instead of silently staying class 0.
    filled = np.zeros((h, w), dtype=bool)
    for batch in dl:
        images = batch["image"].to(device)
        with torch.inference_mode():
            outputs = task(images)
        # NOTE(review): assumes (batch, classes, height, width) outputs, so the
        # argmax over axis 1 is a per-pixel class map — confirm for the task.
        outputs = outputs.argmax(axis=1).cpu().numpy()
        for bb, out in zip(batch["bbox"], outputs):
            # NOTE(review): treating (minx, maxy) as the top-left corner assumes
            # a north-up raster.
            left, top = inv_tfm * (bb.minx, bb.maxy)
            right, bottom = inv_tfm * (bb.maxx, bb.miny)
            left, right = round(left), round(right)
            top, bottom = round(top), round(bottom)
            if right - left != patch_size or bottom - top != patch_size:
                raise RuntimeError(
                    f"Sampled patch of {img} is {bottom - top}x{right - left} "
                    f"pixels, expected {patch_size}x{patch_size}"
                )
            # Drop ``padding`` pixels on sides shared with a neighbouring tile.
            # Keep them on sides at the raster border, where no other tile
            # covers them.
            pad_top = padding if top > 0 else 0
            pad_left = padding if left > 0 else 0
            pad_bottom = padding if bottom < h else 0
            pad_right = padding if right < w else 0
            rows = slice(top + pad_top, bottom - pad_bottom)
            cols = slice(left + pad_left, right - pad_right)
            # Explicit end indices: a ``-pad`` slice end would be empty for pad 0.
            predictions[rows, cols] = out[
                pad_top : patch_size - pad_bottom, pad_left : patch_size - pad_right
            ]
            filled[rows, cols] = True

    n_missing = int(np.count_nonzero(~filled))
    if n_missing:
        print(
            f"WARNING! {n_missing} pixels of {img} were not covered by any "
            "patch and are left as 0."
        )
    return predictions


def _write_prediction(
    predictions: np.ndarray, src_profile: Dict[str, Any], outpath: str
) -> None:
    """Write ``predictions`` as a single-band, LZW-compressed uint8 GeoTIFF.

    All settings not overridden here (e.g. size and transform) come from
    ``src_profile``. The parent directory of ``outpath`` is created if needed.
    """
    # Copy so the caller's profile is not mutated.
    profile = dict(src_profile)
    # compress/predictor are GTiff creation options, and the output is meant
    # to be a GeoTIFF whatever the input format was.
    profile.update(
        driver="GTiff", count=1, dtype="uint8", compress="lzw", predictor=2
    )
    # The source nodata value may not be representable as uint8, and a nodata
    # of 0 would mask class 0, so the class map is written without nodata.
    profile["nodata"] = None
    # Drop the photometric interpretation of a multi-band source (e.g. RGB), if
    # present: it does not apply to a 1-band output.
    profile.pop("photometric", None)
    os.makedirs(os.path.dirname(outpath) or ".", exist_ok=True)
    with rio.open(outpath, "w", **profile) as f:
        f.write(predictions, 1)


def main(
    config_dir: str,
    predict_on: str,
    output_dir: str,
    device: str,
    patch_size: int = 512,
    padding: int = 96,
) -> None:
    """Run a trained torchgeo checkpoint over every raster in a directory.

    Loads ``experiment_config.yaml`` and ``last.ckpt`` from ``config_dir``. It
    predicts every rasterio-readable file under ``predict_on`` in overlapping
    tiles. For each input it writes one single-band uint8 class map into
    ``output_dir``, at the input's path relative to ``predict_on``.

    Args:
        config_dir: directory holding ``experiment_config.yaml`` and ``last.ckpt``.
        predict_on: directory searched recursively for input rasters.
        output_dir: directory that receives the predicted masks.
        device: torch device to run the model on (e.g. "cuda" or "cpu").
        patch_size: side length, in pixels, of each tile fed to the model.
        padding: pixels discarded from each interior side of a tile's
            prediction; neighbouring tiles overlap by ``2 * padding``.

    Raises:
        ValueError: if ``padding`` does not fit ``patch_size`` or
            ``experiment.task`` is not a known task.
        FileExistsError: if ``output_dir`` is not empty and
            ``program.overwrite`` is false.
    """
    if padding < 0 or 2 * padding >= patch_size:
        raise ValueError(
            f"padding={padding} must be >= 0 and less than half of "
            f"patch_size={patch_size}"
        )
    os.makedirs(output_dir, exist_ok=True)

    # Load checkpoint and config
    conf = OmegaConf.load(os.path.join(config_dir, "experiment_config.yaml"))
    ckpt = os.path.join(config_dir, "last.ckpt")

    task_name = conf.experiment.task
    if task_name not in TASK_TO_MODULES_MAPPING:
        raise ValueError(
            f"experiment.task={task_name} is not recognized as a valid task"
        )

    # Check the output directory before the comparatively slow checkpoint load.
    if os.listdir(output_dir):
        if conf.program.overwrite:
            print(
                f"WARNING! The output directory, {output_dir}, already exists, "
                + "we will overwrite data in it!"
            )
        else:
            raise FileExistsError(
                f"The predictions directory, {output_dir}, already exists and isn't "
                + "empty. We don't want to overwrite any existing results, exiting..."
            )

    # Load model
    task_class, datamodule_class = TASK_TO_MODULES_MAPPING[task_name]
    task: pl.LightningModule = task_class.load_from_checkpoint(ckpt)
    task = task.to(device)
    task.eval()

    # Load preprocess tfm
    # NOTE(review): this looks the function up on the *class*. If it is a plain
    # instance method (takes ``self``), calling it with only a sample will
    # fail — confirm it is a staticmethod for the chosen datamodule.
    preprocess = datamodule_class.preprocess

    # Inputs are collected up front, before any prediction is written.
    for img, imgprofile in _find_rasters(predict_on).items():
        predictions = _predict_raster(
            task, img, imgprofile, preprocess, device, patch_size, padding
        )
        # Mirror the input's location under predict_on, so that files with the
        # same basename in different subdirectories don't overwrite each other.
        outpath = os.path.join(output_dir, os.path.relpath(img, predict_on))
        _write_prediction(predictions, imgprofile, outpath)
if __name__ == "__main__":
    # Taken from https://github.com/pangeo-data/cog-best-practices
    _rasterio_best_practices = {
        "GDAL_DISABLE_READDIR_ON_OPEN": "EMPTY_DIR",
        "AWS_NO_SIGN_REQUEST": "YES",
        "GDAL_MAX_RAW_BLOCK_CACHE_SIZE": "200000000",
        "GDAL_SWATH_SIZE": "200000000",
        "VSI_CURL_CACHE_SIZE": "200000000",
    }
    # Only fill in settings the user has not already exported, so an explicit
    # environment (e.g. AWS_NO_SIGN_REQUEST=NO for a private bucket) still wins.
    for _key, _value in _rasterio_best_practices.items():
        os.environ.setdefault(_key, _value)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config-dir",
        type=str,
        required=True,
        help="Path to config-dir to load config and ckpt",
    )
    # "--predict-on" matches the other hyphenated flags; "--predict_on" is kept
    # as an alias so existing command lines keep working.
    parser.add_argument(
        "--predict-on",
        "--predict_on",
        dest="predict_on",
        type=str,
        required=True,
        help="Directory/Dataset to run inference on",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="Path to output_directory to save predicted mask geotiffs",
    )
    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"])
    args = parser.parse_args()
    main(args.config_dir, args.predict_on, args.output_dir, args.device)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.