Skip to content

Instantly share code, notes, and snippets.

@cjwcommuny
Last active November 20, 2021 01:03
Show Gist options
  • Save cjwcommuny/b54e0a93f7c9ca90a49f464a828fe356 to your computer and use it in GitHub Desktop.
Save cjwcommuny/b54e0a93f7c9ca90a49f464a828fe356 to your computer and use it in GitHub Desktop.
Detectron2 Extract Box Features
import cv2
import torch
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from detectron2.model_zoo import model_zoo
from detectron2.utils.visualizer import Visualizer
def extract_box_features(model, image_cv, cfg):
    """Detect objects in one image and return the box-head feature of each detection.

    The model's stages are called one at a time (instead of ``model(inputs)``)
    so the per-box feature vectors produced by ``roi_heads.box_head`` can be
    captured and matched to the detections that survive inference.

    Args:
        model: Detection model exposing ``preprocess_image``, ``backbone``,
            ``proposal_generator``, ``roi_heads`` and ``_postprocess``.
            NOTE(review): presumably must be in eval mode (e.g. taken from a
            ``DefaultPredictor``) -- confirm against the caller.
        image_cv: ``H x W x C`` image array in OpenCV (BGR) channel order,
            e.g. the result of ``cv2.imread``. It is fed to the model at its
            own resolution; no resizing is applied here.
        cfg: Config node. ``cfg.MODEL.DEVICE`` selects the device and
            ``cfg.INPUT.FORMAT`` (``"BGR"`` assumed when absent) selects the
            channel order the model expects.

    Returns:
        A ``(pred_instances, box_feats)`` tuple. ``pred_instances`` is the
        output of ``model._postprocess`` (a one-element list whose entry is
        indexed with ``['instances']`` by the caller); ``box_feats`` holds the
        rows of the box-head output selected by the kept-detection indices.
    """
    device = cfg.MODEL.DEVICE
    height, width = image_cv.shape[:2]

    # HWC -> CHW float tensor on the model's device.
    image = torch.as_tensor(
        image_cv.astype("float32").transpose(2, 0, 1), device=device
    )
    # OpenCV delivers BGR; swap channels when the model was configured for RGB
    # so its pixel normalization is applied to the right channels.
    input_format = getattr(getattr(cfg, "INPUT", None), "FORMAT", "BGR")
    if input_format == "RGB":
        image = image.flip(0)

    inputs = [{"image": image, "height": height, "width": width}]

    with torch.no_grad():
        images = model.preprocess_image(inputs)  # don't forget to preprocess
        features = model.backbone(images.tensor)  # set of cnn features
        proposals, _ = model.proposal_generator(images, features, None)  # RPN

        # ---- roi_heads ---- begin
        roi_heads = model.roi_heads
        pooled_inputs = [features[name] for name in roi_heads.box_in_features]
        proposal_boxes = [p.proposal_boxes for p in proposals]
        box_features = roi_heads.box_pooler(pooled_inputs, proposal_boxes)
        box_features = roi_heads.box_head(box_features)  # features of all candidates
        predictions = roi_heads.box_predictor(box_features)
        pred_instances, pred_inds = roi_heads.box_predictor.inference(
            predictions, proposals
        )
        pred_instances = roi_heads.forward_with_given_boxes(features, pred_instances)
        # ---- roi_heads ---- end

        # output boxes, masks, scores, etc. (boxes scaled back to the original size)
        pred_instances = model._postprocess(pred_instances, inputs, images.image_sizes)

        # ``pred_inds`` has one index tensor per input image; exactly one image
        # is processed here, so select its rows explicitly instead of indexing
        # the feature tensor with a Python list.
        box_feats = box_features[pred_inds[0]]

    return pred_instances, box_feats
if __name__ == '__main__':
    # Demo: detect objects in one image, extract their box features and save
    # a visualization of the predicted boxes.
    config_name = "COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml"
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(config_name))
    # Point WEIGHTS at the released COCO checkpoint for this config; without
    # this the detector heads are not loaded from a trained model.
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(config_name)
    predictor = DefaultPredictor(cfg)
    model = predictor.model

    image_path = '/workspace/cjw2/dataset/ViLT/flickr30k-images/flickr30k-images/8183834568.jpg'
    image = cv2.imread(image_path)
    # cv2.imread signals an unreadable/missing file by returning None.
    if image is None:
        raise FileNotFoundError(f"could not read image: {image_path}")

    pred_instances, box_feats = extract_box_features(model, image, cfg)

    # Visualizer is given the image with channels reversed (BGR -> RGB); the
    # drawn result is reversed back so OpenCV can write it.
    visualizer = Visualizer(image[:, :, ::-1], scale=1.2)
    drawn = visualizer.draw_instance_predictions(pred_instances[0]['instances'].to("cpu"))
    image_with_boxes = drawn.get_image()[:, :, ::-1]  # opencv format image
    cv2.imwrite('image_with_boxes.jpg', image_with_boxes)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment