Skip to content

Instantly share code, notes, and snippets.

@PallawiSinghal
Created February 27, 2022 08:22
Show Gist options
  • Save PallawiSinghal/0595d00cc43bb218c04ddd8760e7f022 to your computer and use it in GitHub Desktop.
Save PallawiSinghal/0595d00cc43bb218c04ddd8760e7f022 to your computer and use it in GitHub Desktop.
import os
from point_rend.config import add_pointrend_config #most important
# from detectron2.utils.logger import setup_logger
# setup_logger()
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data.datasets import register_coco_instances
from detectron2.engine import DefaultTrainer
# Fine-tune PointRend (X-101-32x8d FPN) instance segmentation on a custom
# single-class COCO-format dataset, dumping the effective config next to the
# training outputs before training starts.

# Root directory holding the COCO-format annotation JSONs and image folders.
DATA_ROOT = "/code/detectron2/detectron2"

# Register the train/val/test splits under the names the config refers to:
#   my_dataset_<split> -> <DATA_ROOT>/instances_<split>2017.json, <DATA_ROOT>/<split>2017
# NOTE(review): "my_dataset_val" is registered but not referenced by the config
# below (TEST uses "my_dataset_test") — confirm which split should be evaluated.
for split in ("train", "val", "test"):
    register_coco_instances(
        f"my_dataset_{split}",
        {},
        os.path.join(DATA_ROOT, f"instances_{split}2017.json"),
        os.path.join(DATA_ROOT, f"{split}2017"),
    )

cfg = get_cfg()
# PointRend-specific keys (e.g. MODEL.POINT_HEAD, used below) are not part of
# the base detectron2 config; they must be added before merging the PointRend
# YAML, which presumably would otherwise be rejected for unknown keys.
add_pointrend_config(cfg)
cfg.merge_from_file(
    "/code/detectron2/projects/PointRend/configs/InstanceSegmentation/"
    "pointrend_rcnn_X_101_32x8d_FPN_3x_coco.yaml"
)
# Start from the ImageNet-pretrained X-101-32x8d weights (per the URL) rather
# than a COCO-trained PointRend checkpoint.
cfg.MODEL.WEIGHTS = "detectron2://ImageNetPretrained/FAIR/X-101-32x8d.pkl"

cfg.DATASETS.TRAIN = ("my_dataset_train",)
cfg.DATASETS.TEST = ("my_dataset_test",)
cfg.DATALOADER.NUM_WORKERS = 2

# NOTE(review): SOLVER.MAX_ITER / STEPS are not overridden, so they come from
# the merged 3x YAML; confirm that schedule is intended with IMS_PER_BATCH = 1.
cfg.SOLVER.IMS_PER_BATCH = 1
cfg.SOLVER.BASE_LR = 0.00025

# Single foreground class: both the box/mask head and the PointRend point
# head must agree on the class count.
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
cfg.MODEL.POINT_HEAD.NUM_CLASSES = 1

# Several training short-edge sizes (scale jitter); presumably one is sampled
# per image under the default INPUT.MIN_SIZE_TRAIN_SAMPLING — confirm.
cfg.INPUT.MAX_SIZE_TRAIN = 1333
cfg.INPUT.MIN_SIZE_TRAIN = (1024, 1075, 1126, 1178, 1230, 1280)
cfg.INPUT.MAX_SIZE_TEST = 1333
cfg.INPUT.MIN_SIZE_TEST = 1280
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7

cfg.OUTPUT_DIR = "/code/detectron2/detectron2/output/"
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

# Persist the fully-resolved config alongside the outputs for reproducibility.
# Derived from OUTPUT_DIR so the two cannot drift apart (same path as before).
CONFIG_DUMP_PATH = os.path.join(
    cfg.OUTPUT_DIR,
    "pointrend_custom_mask_rcnn_X_101_32x8d_FPN_3x_feb_data_train_1_1280_scale_jitter_005.yaml",
)
cfg_yaml = cfg.dump()
print(cfg_yaml)
# The write must be inside the `with` block; the original lost its indentation,
# which made the whole script fail with IndentationError.
with open(CONFIG_DUMP_PATH, "w", encoding="utf-8") as f:
    f.write(cfg_yaml)

trainer = DefaultTrainer(cfg)
# resume=False: load cfg.MODEL.WEIGHTS instead of resuming from the last
# checkpoint in OUTPUT_DIR.
trainer.resume_or_load(resume=False)
trainer.train()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment