Skip to content

Instantly share code, notes, and snippets.

@spirosdim
Last active November 11, 2021 13:51
Show Gist options
  • Save spirosdim/1f5aa6bcfb17c892606d550f3bea7c03 to your computer and use it in GitHub Desktop.
Save spirosdim/1f5aa6bcfb17c892606d550f3bea7c03 to your computer and use it in GitHub Desktop.
import torch
from detectron2 import model_zoo
from detectron2.engine import DefaultTrainer
from detectron2.solver.build import get_default_optimizer_params
from detectron2.solver.build import maybe_add_gradient_clipping
class MyTrainer(DefaultTrainer):
    """DefaultTrainer subclass whose optimizer is torch.optim.AdamW."""

    @classmethod
    def build_optimizer(cls, cfg, model):
        """Build an AdamW optimizer from config.

        Args:
            cfg: config node. ``SOLVER.BASE_LR`` and ``SOLVER.WEIGHT_DECAY``
                are read here; the whole node is also handed to
                ``maybe_add_gradient_clipping``.
            model: the model whose parameters are collected for optimization.

        Returns:
            An instance of the optimizer class produced by
            ``maybe_add_gradient_clipping(cfg, torch.optim.AdamW)``.
        """
        # Only `model` is passed, so no lr / weight-decay arguments are
        # forwarded to the parameter-collection helper; the values given to
        # the optimizer constructor below are the ones set explicitly here.
        params = get_default_optimizer_params(model)
        optimizer_cls = maybe_add_gradient_clipping(cfg, torch.optim.AdamW)
        return optimizer_cls(
            params,
            lr=cfg.SOLVER.BASE_LR,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        )
import os, json, yaml, datetime
from detectron2.config import get_cfg
# Model-zoo entry used both for the base config file and the initial weights.
base_model = "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml"

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(base_model))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(base_model)

# --- Model ---
cfg.MODEL.BACKBONE.FREEZE_AT = 2
cfg.MODEL.ANCHOR_GENERATOR.SIZES = [[32, 64, 128, 256]]
cfg.MODEL.RPN.NMS_THRESH = 0.7
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 256
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
# Mean/std are listed in the same channel order as INPUT.FORMAT below.
cfg.MODEL.PIXEL_MEAN = [123.675, 116.28, 103.53]
cfg.MODEL.PIXEL_STD = [58.395, 57.12, 57.375]

# --- Input ---
cfg.INPUT.FORMAT = "RGB"
cfg.INPUT.RANDOM_FLIP = "none"
cfg.INPUT.MIN_SIZE_TRAIN = (256,)
cfg.INPUT.MAX_SIZE_TRAIN = 256
cfg.INPUT.MIN_SIZE_TEST = 0

# --- Solver ---
cfg.SOLVER.IMS_PER_BATCH = 4
cfg.SOLVER.MAX_ITER = 4000
cfg.SOLVER.BASE_LR = 8e-4
# Fixed: this line previously assigned to the undefined name `cfd`, which
# raised NameError and left the weight decay at its inherited value.
cfg.SOLVER.WEIGHT_DECAY = 0.0001
cfg.SOLVER.LR_SCHEDULER_NAME = "WarmupCosineLR"
# Warm up for the first 20% of the schedule.
cfg.SOLVER.WARMUP_ITERS = int(0.2 * cfg.SOLVER.MAX_ITER)
cfg.SOLVER.CLIP_GRADIENTS.ENABLED = True
cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "value"
cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE = 1.0
cfg.SOLVER.AMP.ENABLED = True

# --- Output: one timestamped directory per run ---
now = datetime.datetime.now()
cfg.OUTPUT_DIR = f'./logs/{now.strftime("%Y%m%d_%H%M%S")}'
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

# Dump the config file in the output directory.
# cfg.dump() is the config node's own YAML serializer, so the saved file can
# be merged back with cfg.merge_from_file(); a generic yaml.dump() of the node
# object does not go through that serializer. The `with` block closes the
# file, so no explicit close() is needed.
with open(os.path.join(cfg.OUTPUT_DIR, "config.yaml"), "w") as config_file:
    config_file.write(cfg.dump())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment