Skip to content

Instantly share code, notes, and snippets.

@calebrob6
Last active May 22, 2024 02:20
Show Gist options
  • Save calebrob6/af17853f9d9e9cac2b817950226c37c9 to your computer and use it in GitHub Desktop.
Save calebrob6/af17853f9d9e9cac2b817950226c37c9 to your computer and use it in GitHub Desktop.
Runs inference on large satellite image scenes with SAM models
# Requires the `segment-geospatial` package https://samgeo.gishub.org/
import argparse
import os
import cv2
import numpy as np
import rasterio
import rasterio.features
import rasterio.transform
import rasterio.windows
from tqdm.contrib import itertools
import logging
from samgeo.fast_sam import SamGeo
def add_parser_args(parser):
    """Register the command-line options of this script on ``parser``.

    The parser is modified in place and nothing is returned.

    ``--input_fn`` and ``--output_fn`` are mandatory: the rest of the script
    cannot do anything useful without them, so a missing path is rejected by
    argparse with a usage message instead of failing later on a ``None`` value.

    Args:
        parser: An ``argparse.ArgumentParser`` (or any object exposing a
            compatible ``add_argument`` method).
    """
    # Input / output locations.
    parser.add_argument(
        "--input_fn", type=str, required=True, help="Input GeoTIFF imagery"
    )
    parser.add_argument(
        "--output_fn", type=str, required=True, help="Output GeoTIFF mask"
    )

    # Inference set-up: device index and tiling geometry.
    parser.add_argument(
        "--device", type=int, default=0, help="Device to run inference on"
    )
    parser.add_argument(
        "--patch_size", type=int, default=1024, help="Patch size for inference"
    )
    parser.add_argument(
        "--upsample_size",
        type=int,
        default=1,
        help="Upsample factor for input imagery",
    )
    parser.add_argument(
        "--padding", type=int, default=256, help="Padding for input imagery"
    )

    # Behaviour switches.
    parser.add_argument(
        "--overwrite", action="store_true", help="Overwrite existing output file"
    )
    parser.add_argument(
        "--skip_valid_checks", action="store_true", help="Skip validation checks"
    )

    # Mask post-processing / detector thresholds.
    parser.add_argument(
        "--size_filter_pixels",
        type=int,
        default=100_000,
        help="Filter size for mask postprocessing",
    )
    parser.add_argument(
        "--iou",
        type=float,
        default=0.01,
        help="Intersection over union threshold for mask postprocessing",
    )
    parser.add_argument(
        "--confidence",
        type=float,
        default=0.00,
        help="Confidence threshold for mask postprocessing",
    )
    parser.add_argument(
        "--max_detections",
        type=int,
        default=10_000,
        help="Maximum number of detections for mask postprocessing",
    )
def run_sam_on_img(
    img, sam, upsample_size, size_filter_pixels, **kwargs
):
    """Segment one image patch and return a single binary mask for it.

    All detections whose pixel sum is below ``size_filter_pixels`` are merged
    into one mask, which is cleaned with a 3x3 morphological close followed by
    an 8x8 morphological open and, if needed, resized back to the patch size.

    Args:
        img: Image patch indexed as (height, width, channels).
        sam: Model wrapper exposing ``set_image(img, **kwargs)`` and
            ``everything_prompt(output=None)``.
        upsample_size: Upsampling factor used by the caller. Kept for
            interface compatibility; the working resolution is taken from the
            masks the model actually returns, so a mismatch no longer raises.
        size_filter_pixels: Detections whose mask sum is greater than or equal
            to this value are discarded. The sum is taken at the resolution of
            the returned masks.
        **kwargs: Forwarded unchanged to ``sam.set_image``.

    Returns:
        A ``uint8`` array of shape ``(height, width)`` containing 0 or 1.
    """
    t_height, t_width, _ = img.shape

    # args for this aren't documented anywhere, see
    # https://github.com/CASIA-IVA-Lab/FastSAM/blob/main/predict.py#L87 for options
    sam.set_image(img, **kwargs)
    detections = sam.everything_prompt(output=None)

    # Nothing detected: the result is an empty mask already at patch size.
    if len(detections) == 0:
        return np.zeros((t_height, t_width), dtype=np.uint8)

    # NOTE(review): `detections` is presumably an (N, H, W) torch tensor, as
    # the `dim=` reductions and `.cpu()` below require - confirm against samgeo.
    mask_sizes = detections.sum(dim=(1, 2))
    keep = mask_sizes < size_filter_pixels

    # Union of all kept detections, reduced on the tensor's own device so only
    # one array is copied to the host instead of one per detection.
    merged = (detections[keep] > 0).any(dim=0)
    mask = merged.cpu().numpy().astype(np.uint8)

    # Close small holes first, then drop specks smaller than the 8x8 kernel.
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))

    # Resize mask if necessary (happens when upsample size is > 1).
    if mask.shape != (t_height, t_width):
        mask = cv2.resize(
            mask,
            (t_width, t_height),
            interpolation=cv2.INTER_NEAREST,
        )
    return mask
def _tile_offsets(length, patch_size, stride):
    """Return the patch start offsets that cover one axis of ``length`` pixels."""
    # The final patch is anchored to the far edge so no window overhangs the
    # raster (offset 0 is used when the raster is smaller than one patch).
    last = max(length - patch_size, 0)
    return list(range(0, last, stride)) + [last]


def _keep_span(offset, patch_size, padding, length):
    """Return the ``(start, stop)`` part of a patch to copy out along one axis."""
    # Padding is trimmed only on sides that have a neighbouring patch; sides
    # touching the raster border are kept so the border is always written.
    start = 0 if offset == 0 else padding
    stop = patch_size if offset + patch_size >= length else patch_size - padding
    return start, min(stop, length - offset)


def main(args):
    """Run tiled FastSAM inference over a GeoTIFF and write a one-band mask.

    The raster is processed in overlapping ``patch_size`` windows. For each
    window only the centre (the patch minus ``padding`` on every side that has
    a neighbouring patch) is copied to the output, so every pixel is predicted
    with context around it where possible.

    Args:
        args: Parsed command-line namespace. Reads ``input_fn``, ``output_fn``,
            ``device``, ``patch_size``, ``upsample_size``, ``padding``,
            ``overwrite``, ``skip_valid_checks``, ``size_filter_pixels``,
            ``iou``, ``confidence`` and ``max_detections``.

    Raises:
        ValueError: If ``patch_size`` is not larger than ``2 * padding``, if
            ``padding`` is negative, if a path does not end in ``.tif`` (unless
            validation is skipped), or if the input has fewer than 3 bands.
        FileNotFoundError: If the input file is missing (unless validation is
            skipped).
        FileExistsError: If the output exists and ``--overwrite`` was not given.
    """
    patch_size = args.patch_size
    padding = args.padding
    stride = patch_size - 2 * padding
    # A non-positive stride would make the tiling loop meaningless.
    if padding < 0 or stride <= 0:
        raise ValueError(
            f"patch_size ({patch_size}) must be larger than twice the padding "
            f"({padding}), and padding must not be negative"
        )

    if not args.skip_valid_checks:
        if not os.path.exists(args.input_fn):
            raise FileNotFoundError(f"Input file {args.input_fn} does not exist")
        if not args.input_fn.endswith(".tif"):
            raise ValueError(f"Input file {args.input_fn} must end with .tif")
        if not args.output_fn.endswith(".tif"):
            raise ValueError(f"Output file {args.output_fn} must end with .tif")

    # The overwrite guard has its own flag, so it is applied even when the
    # other validation checks are skipped.
    if os.path.exists(args.output_fn):
        if not args.overwrite:
            raise FileExistsError(f"Output file {args.output_fn} already exists")
        print("WARNING: Will overwrite existing output file")

    # Load SAM
    print("Loading SAM model...")
    device = f"cuda:{args.device}"
    sam = SamGeo(model="FastSAM-x.pt")

    # Stop YOLO8 logging
    logging.getLogger("ultralytics").setLevel(logging.CRITICAL)

    # Run SAM. The dataset is opened once and reused for every window.
    with rasterio.open(args.input_fn) as src:
        height, width = src.shape
        profile = src.profile
        if src.count > 3:
            print("Input imagery has more than 3 bands, will use only the first 3")
        elif src.count < 3:
            raise ValueError("Input imagery must have at least 3 bands")

        outputs = np.zeros((height, width), dtype=np.uint8)
        ys = _tile_offsets(height, patch_size, stride)
        xs = _tile_offsets(width, patch_size, stride)
        # Windows only overhang when the raster is smaller than one patch; in
        # that case read boundless so the patch is zero-filled to full size.
        overhang = height < patch_size or width < patch_size

        print("Running inference...")
        # `itertools` is tqdm.contrib.itertools, so this loop shows progress.
        for y, x in itertools.product(ys, xs):
            window = rasterio.windows.Window(x, y, patch_size, patch_size)
            bands = src.read(
                [1, 2, 3], window=window, boundless=overhang, fill_value=0
            )
            # (bands, rows, cols) -> (rows, cols, bands), contiguous in memory.
            img = np.ascontiguousarray(bands.transpose(1, 2, 0))

            # TODO: It makes sense that some blurring here might help, but IDK
            mask = run_sam_on_img(
                img,
                sam,
                upsample_size=args.upsample_size,
                size_filter_pixels=args.size_filter_pixels,
                imgsz=int(patch_size * args.upsample_size),
                iou=args.iou,
                conf=args.confidence,
                max_det=args.max_detections,
                device=device,
            )

            y0, y1 = _keep_span(y, patch_size, padding, height)
            x0, x1 = _keep_span(x, patch_size, padding, width)
            outputs[y + y0 : y + y1, x + x0 : x + x1] = mask[y0:y1, x0:x1]

    # Single-band uint8 mask with 0 as nodata; other profile keys are inherited.
    profile["count"] = 1
    profile["nodata"] = 0
    profile["dtype"] = "uint8"
    with rasterio.open(args.output_fn, "w", **profile) as dst:
        dst.write(outputs, 1)
if __name__ == "__main__":
    # Script entry point: build the CLI, parse it, and run inference.
    cli = argparse.ArgumentParser()
    add_parser_args(cli)
    main(cli.parse_args())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment