Created April 9, 2021 05:43
-
-
Save trongan93/bfb9a5412eed091f43dd8bbbfaaebaae to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Fine-tune a COCO-pretrained Mask R-CNN (ResNet-50 FPN, 3x schedule) on the
# "rareplanes_dataset_train" dataset using detectron2's DefaultTrainer.
#
# Prerequisite (not shown here): a dataset named "rareplanes_dataset_train"
# must already be registered with detectron2's DatasetCatalog / MetadataCatalog
# (e.g. via detectron2.data.datasets.register_coco_instances) before this runs.
import os

from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine import DefaultTrainer

# Single source of truth for the model-zoo entry, used for both the
# architecture config and the initial weights so the two cannot drift apart.
MODEL_ZOO_CONFIG = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(MODEL_ZOO_CONFIG))
cfg.DATASETS.TRAIN = ("rareplanes_dataset_train",)
cfg.DATASETS.TEST = ()  # no evaluation dataset: training only
cfg.DATALOADER.NUM_WORKERS = 2
# Let training initialize from the COCO-pretrained checkpoint matching the config above.
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(MODEL_ZOO_CONFIG)
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.00025
# 300 iterations is the short schedule from the detectron2 balloon tutorial.
# NOTE(review): likely too short for a practical RarePlanes run; tune as needed.
cfg.SOLVER.MAX_ITER = 300
cfg.SOLVER.STEPS = []  # empty: no learning-rate decay steps
# RoIs sampled per image for the ROI heads; smaller than detectron2's default
# (512) for faster training.
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
# Number of foreground classes; do NOT add 1 for background, although some
# unofficial tutorials incorrectly use num_classes + 1 here. See
# https://detectron2.readthedocs.io/tutorials/datasets.html#update-the-config-for-new-datasets
# Assumes "rareplanes_dataset_train" was registered with exactly one category;
# confirm against the dataset registration.
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1

# DefaultTrainer writes its outputs (e.g. checkpoints) under cfg.OUTPUT_DIR,
# so make sure the directory exists before training starts.
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg)
# resume=False: load cfg.MODEL.WEIGHTS instead of resuming from a previous
# checkpoint in cfg.OUTPUT_DIR.
trainer.resume_or_load(resume=False)
trainer.train()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.