Skip to content

Instantly share code, notes, and snippets.

@spirosdim
Last active April 6, 2022 12:44
Show Gist options
  • Save spirosdim/02d5b32a1b447cdef6e42f651300edca to your computer and use it in GitHub Desktop.
Save spirosdim/02d5b32a1b447cdef6e42f651300edca to your computer and use it in GitHub Desktop.
Detectron2 Inference script
import os

import torch
from detectron2.modeling import build_model
from detectron2.checkpoint import DetectionCheckpointer
import detectron2.data.transforms as T
from detectron2.data.detection_utils import read_image
class MyPredictor:
    """Batched inference wrapper around a detectron2 model.

    Adapted from detectron2's ``DefaultPredictor``. Unlike that class, it takes
    image *paths*, reads them from disk, and runs the whole batch through the
    model in a single forward call.

    source: https://github.com/facebookresearch/detectron2/blob/08f617cc2c9276a48e7e7dc96ae946a5df23af3f/detectron2/engine/defaults.py#L252

    Args:
        cfg: a detectron2 config node. It is cloned, so the predictor keeps its
            own copy independent of the caller's object.
    """

    def __init__(self, cfg):
        self.cfg = cfg.clone()  # cfg can be modified by model
        self.model = build_model(self.cfg)
        self.model.eval()  # inference mode for layers such as BatchNorm/Dropout

        checkpointer = DetectionCheckpointer(self.model)
        checkpointer.load(cfg.MODEL.WEIGHTS)

        # Test-time resize, identical to DefaultPredictor: shortest edge to
        # MIN_SIZE_TEST, longest edge capped at MAX_SIZE_TEST.
        self.aug = T.ResizeShortestEdge(
            [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST],
            cfg.INPUT.MAX_SIZE_TEST,
        )

        self.input_format = cfg.INPUT.FORMAT
        # NOTE: assert is stripped under ``python -O``; kept for parity with upstream.
        assert self.input_format in ["RGB", "BGR"], self.input_format

    def __call__(self, imgs_paths):
        """Run inference on a batch of images read from disk.

        Args:
            imgs_paths: an iterable of image paths (``str`` or ``os.PathLike``).
                A single path is also accepted and treated as a batch of one.

        Returns:
            list[dict]: the model's outputs, one entry per input image, in
            input order. See :detectron2 doc:`/tutorials/models` for details
            about the format. An empty input returns ``[]`` without invoking
            the model.
        """
        # A bare string is itself iterable; without this it would be split
        # into one "path" per character.
        if isinstance(imgs_paths, (str, os.PathLike)):
            imgs_paths = [imgs_paths]
        # Materialise once (works for generators) and normalise Path objects
        # to plain strings for read_image.
        paths = [os.fspath(path) for path in imgs_paths]
        if not paths:
            # Nothing to run; avoid handing the model an empty batch.
            return []

        with torch.no_grad():
            # Apply pre-processing to each image.
            inputs = []
            for file_name in paths:
                # Read directly in the model's channel order, so no flip is needed.
                original_image = read_image(file_name, format=self.input_format)
                height, width = original_image.shape[:2]
                image = self.aug.get_transform(original_image).apply_image(original_image)
                # HWC uint8 array -> CHW float32 tensor.
                image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
                # height/width tell the model the resolution to report its
                # outputs at (the original image size, not the resized one).
                inputs.append({"image": image, "height": height, "width": width})

            predictions = self.model(inputs)
            assert len(inputs) == len(predictions)
            return predictions
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment