Last active
April 6, 2022 12:44
-
-
Save spirosdim/02d5b32a1b447cdef6e42f651300edca to your computer and use it in GitHub Desktop.
Detectron2 Inference script
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from detectron2.modeling import build_model | |
from detectron2.checkpoint import DetectionCheckpointer | |
import detectron2.data.transforms as T | |
from detectron2.data.detection_utils import read_image | |
class MyPredictor:
    """Batch inference wrapper around a detectron2 model.

    Builds the model described by ``cfg``, loads the weights named in
    ``cfg.MODEL.WEIGHTS`` and, when called, runs the model on a batch of
    images read from disk.

    source: https://github.com/facebookresearch/detectron2/blob/08f617cc2c9276a48e7e7dc96ae946a5df23af3f/detectron2/engine/defaults.py#L252
    """

    def __init__(self, cfg):
        """Build the model, load its checkpoint and set up test-time resizing.

        Args:
            cfg: detectron2 config node. It is cloned before being handed to
                ``build_model`` because the model may modify it.
        """
        self.cfg = cfg.clone()  # cfg can be modified by model
        self.model = build_model(self.cfg)
        self.model.eval()

        checkpointer = DetectionCheckpointer(self.model)
        checkpointer.load(cfg.MODEL.WEIGHTS)

        # Passing the same value as both ends of the range pins the
        # shortest-edge size to MIN_SIZE_TEST instead of sampling one.
        self.aug = T.ResizeShortestEdge(
            [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST],
            cfg.INPUT.MAX_SIZE_TEST,
        )

        self.input_format = cfg.INPUT.FORMAT
        assert self.input_format in ["RGB", "BGR"], self.input_format

    def _prepare_input(self, file_name):
        """Read one image and turn it into a model input dict.

        Returns a dict with the resized image as a float32 CHW tensor under
        ``"image"`` and the *original* image size under ``"height"`` and
        ``"width"``.
        """
        original_image = read_image(file_name, format=self.input_format)
        height, width = original_image.shape[:2]
        image = self.aug.get_transform(original_image).apply_image(original_image)
        # HWC numpy array -> CHW float32 tensor.
        image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
        return {"image": image, "height": height, "width": width}

    def __call__(self, imgs_paths):
        """Run inference on a batch of image files.

        Args:
            imgs_paths (iterable): image paths; a list is typical, but any
                iterable of paths is accepted.

        Returns:
            predictions (list):
                the output of the model, one entry per input image; an empty
                list when no paths are given.
                See :detectron2 doc:`/tutorials/models` for details about the format.
        """
        # Apply pre-processing to every image before touching the model.
        inputs = [self._prepare_input(file_name) for file_name in imgs_paths]
        if not inputs:
            # Nothing to predict; skip the forward pass on an empty batch.
            return []

        with torch.no_grad():
            predictions = self.model(inputs)
        assert len(inputs) == len(predictions)
        return predictions
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment