Last active
June 23, 2020 00:46
-
-
Save ivder/2c86d8631a79a6a11f0d237c58baee95 to your computer and use it in GitHub Desktop.
Register a dataset, train, and run inference using Detectron2 (Mask R-CNN)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import detectron2 | |
from detectron2.utils.logger import setup_logger | |
setup_logger() | |
# import some common libraries | |
import numpy as np | |
import cv2 | |
import random | |
import datetime | |
# import some common detectron2 utilities | |
from detectron2 import model_zoo | |
from detectron2.engine import DefaultPredictor | |
from detectron2.config import get_cfg | |
from detectron2.utils.visualizer import Visualizer | |
from detectron2.data import MetadataCatalog | |
import os | |
import numpy as np | |
import json | |
from detectron2.structures import BoxMode | |
def get_damage_dicts(img_dir):
    """Build Detectron2-style dataset dicts from a VIA annotation file.

    Reads ``via_region_data.json`` from *img_dir* and produces one record per
    entry in that file. Every polygon region becomes one annotation of
    category ``0``.

    Args:
        img_dir: Directory containing the images and ``via_region_data.json``.

    Returns:
        A list of dicts, each with the keys ``file_name``, ``image_id``,
        ``height``, ``width`` and ``annotations``. Each annotation holds
        ``bbox`` (in ``BoxMode.XYXY_ABS``), ``bbox_mode``, ``segmentation``
        (a single flat ``[x0, y0, x1, y1, ...]`` polygon) and ``category_id``.

    Raises:
        OSError: If an image referenced by the JSON file cannot be read.
        AssertionError: If a region carries non-empty ``region_attributes``.
    """
    json_file = os.path.join(img_dir, "via_region_data.json")
    with open(json_file) as f:
        imgs_anns = json.load(f)

    dataset_dicts = []
    for idx, v in enumerate(imgs_anns.values()):
        filename = os.path.join(img_dir, v["filename"])

        # cv2.imread signals failure by returning None instead of raising, so
        # check explicitly rather than failing later on ``None.shape``.
        image = cv2.imread(filename)
        if image is None:
            raise OSError(
                "could not read image {!r} referenced by {!r}".format(
                    filename, json_file
                )
            )
        height, width = image.shape[:2]

        # ``regions`` may be a list or a dict keyed by region index; iterating
        # a dict directly would yield its keys, so normalise to the values.
        annos = v["regions"]
        if isinstance(annos, dict):
            annos = annos.values()

        objs = []
        for anno in annos:
            # Only un-attributed regions are expected (single-class dataset).
            assert not anno["region_attributes"]
            shape = anno["shape_attributes"]
            px = shape["all_points_x"]
            py = shape["all_points_y"]
            # Shift each vertex by half a pixel and flatten to
            # [x0, y0, x1, y1, ...].
            poly = [coord + 0.5 for point in zip(px, py) for coord in point]
            objs.append(
                {
                    "bbox": [np.min(px), np.min(py), np.max(px), np.max(py)],
                    "bbox_mode": BoxMode.XYXY_ABS,
                    "segmentation": [poly],
                    "category_id": 0,
                }
            )

        dataset_dicts.append(
            {
                "file_name": filename,
                "image_id": idx,
                "height": height,
                "width": width,
                "annotations": objs,
            }
        )
    return dataset_dicts
### prepare dataset
from detectron2.data import DatasetCatalog, MetadataCatalog

# Register both splits lazily. The previous eager
# ``get_damage_dicts("damage/" + d)`` calls discarded their result, so they
# only repeated the parsing work that the registered loader performs anyway.
for d in ["train", "val"]:
    # ``d=d`` freezes the current split name; a plain closure would see only
    # the last value of ``d`` by the time the loader is invoked.
    DatasetCatalog.register("damage_" + d, lambda d=d: get_damage_dicts("damage/" + d))
    MetadataCatalog.get("damage_" + d).set(thing_classes=["damage"])
damage_metadata = MetadataCatalog.get("damage_train")

### check annotation
# Draw the ground-truth annotations of up to three random training images and
# save them as test1.jpg, test2.jpg, ... for a visual sanity check.
dataset_dicts = get_damage_dicts("damage/train")
# random.sample raises ValueError when asked for more items than exist, so
# cap the sample size for very small datasets.
sample_size = min(3, len(dataset_dicts))
for x, d in enumerate(random.sample(dataset_dicts, sample_size), start=1):
    img = cv2.imread(d["file_name"])
    # OpenCV loads BGR; the Visualizer is fed the channel-reversed image and
    # its output is reversed back before writing with OpenCV.
    visualizer = Visualizer(img[:, :, ::-1], metadata=damage_metadata, scale=0.5)
    out = visualizer.draw_dataset_dict(d)
    cv2.imwrite("test" + str(x) + ".jpg", out.get_image()[:, :, ::-1])
### training
from detectron2.engine import DefaultTrainer
from detectron2.config import get_cfg

# Start from the model-zoo Mask R-CNN R50-FPN config and override what differs
# for the single-class "damage" dataset.
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.DATASETS.TRAIN = ("damage_train",)
cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 2
# Let training initialize from the model zoo checkpoint.
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.00025  # pick a good LR
cfg.SOLVER.MAX_ITER = 1500  # you may need to train longer for a practical dataset
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128  # faster, and good enough for a toy dataset (default: 512)
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # only has one class (damage)

# Training is off by default, matching the previously commented-out block.
# Set RUN_TRAINING to True to train and write checkpoints to cfg.OUTPUT_DIR.
RUN_TRAINING = False
if RUN_TRAINING:
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    trainer = DefaultTrainer(cfg)
    trainer.resume_or_load(resume=False)
    trainer.train()
### inference
import time

from detectron2.utils.visualizer import ColorMode

print ("doing inference")
# Load the weights written by training into cfg.OUTPUT_DIR.
cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7  # set the testing threshold for this model
cfg.DATASETS.TEST = ("damage_val", )
predictor = DefaultPredictor(cfg)

# Images to process are listed one path per line in this file. To run on the
# validation split instead, iterate get_damage_dicts("damage/val") and read
# each record's "file_name".
IMAGE_LIST = "damage/yolotest.txt"
RESULT_DIR = "yolo_result"

# cv2.imwrite does not create missing directories, so make sure the target
# exists before the loop.
os.makedirs(RESULT_DIR, exist_ok=True)

with open(IMAGE_LIST, "r") as f:
    for x in f:
        x = x.strip()
        if not x:
            # Blank lines (e.g. a trailing newline) are not image paths.
            continue
        im = cv2.imread(x)
        if im is None:
            # cv2.imread returns None for missing/undecodable files; skip
            # instead of handing None to the predictor.
            print("skipping unreadable image: " + x)
            continue
        print(x)

        # Report the duration of the predictor call in whole milliseconds,
        # measured with a monotonic clock.
        t1 = time.perf_counter()
        outputs = predictor(im)
        t2 = time.perf_counter()
        print (int((t2 - t1) * 1000))

        v = Visualizer(im[:, :, ::-1],
                       metadata=damage_metadata,
                       scale=0.8,
                       instance_mode=ColorMode.IMAGE_BW  # remove the colors of unsegmented pixels
                       )
        out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
        # Save under the input's base name inside the result directory.
        cv2.imwrite(os.path.join(RESULT_DIR, os.path.basename(x)), out.get_image()[:, :, ::-1])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment