Last active
November 11, 2021 13:51
-
-
Save spirosdim/1f5aa6bcfb17c892606d550f3bea7c03 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
from detectron2.engine import DefaultTrainer
from detectron2.solver.build import get_default_optimizer_params
from detectron2.solver.build import maybe_add_gradient_clipping
class MyTrainer(DefaultTrainer):
    """Trainer variant whose optimizer is AdamW.

    Only ``build_optimizer`` is overridden; everything else is inherited
    from ``DefaultTrainer``.
    """

    @classmethod
    def build_optimizer(cls, cfg, model):
        """Create the AdamW optimizer for ``model`` from ``cfg``.

        Args:
            cfg: Config node; ``cfg.SOLVER.BASE_LR`` and
                ``cfg.SOLVER.WEIGHT_DECAY`` are read here, and the whole
                node is handed to ``maybe_add_gradient_clipping``.
            model: The model whose parameters will be optimized.

        Returns:
            An optimizer instance built from ``torch.optim.AdamW`` after it
            has been passed through ``maybe_add_gradient_clipping``.
        """
        param_groups = get_default_optimizer_params(model)
        # The helper receives the optimizer *class* and returns the class
        # to instantiate, so it is called before constructing the optimizer.
        optimizer_cls = maybe_add_gradient_clipping(cfg, torch.optim.AdamW)
        return optimizer_cls(
            param_groups,
            lr=cfg.SOLVER.BASE_LR,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import datetime
import json
import os

import yaml
from detectron2 import model_zoo
from detectron2.config import get_cfg

# Build the training configuration on top of a model-zoo baseline, create a
# timestamped output directory, and save the resolved config next to the logs.
base_model = "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml"

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(base_model))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(base_model)

# --- Model ---
cfg.MODEL.BACKBONE.FREEZE_AT = 2
# NOTE(review): a single size list is given for the anchor generator; confirm
# this matches the number of anchors the pretrained RPN head expects.
cfg.MODEL.ANCHOR_GENERATOR.SIZES = [[32, 64, 128, 256]]
cfg.MODEL.RPN.NMS_THRESH = 0.7
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 256
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
cfg.MODEL.PIXEL_MEAN = [123.675, 116.28, 103.53]
cfg.MODEL.PIXEL_STD = [58.395, 57.12, 57.375]

# --- Input ---
cfg.INPUT.FORMAT = "RGB"
cfg.INPUT.RANDOM_FLIP = "none"
cfg.INPUT.MIN_SIZE_TRAIN = (256,)
cfg.INPUT.MAX_SIZE_TRAIN = 256
cfg.INPUT.MIN_SIZE_TEST = 0

# --- Solver ---
cfg.SOLVER.IMS_PER_BATCH = 4
cfg.SOLVER.MAX_ITER = 4000
cfg.SOLVER.BASE_LR = 8e-4
# Fixed: the original assigned to the undefined name `cfd`, which raised
# NameError before training could start.
cfg.SOLVER.WEIGHT_DECAY = 0.0001
cfg.SOLVER.LR_SCHEDULER_NAME = "WarmupCosineLR"
# Warm up over the first 20% of the iterations.
cfg.SOLVER.WARMUP_ITERS = int(0.2 * cfg.SOLVER.MAX_ITER)
cfg.SOLVER.CLIP_GRADIENTS.ENABLED = True
cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "value"
cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE = 1.0
cfg.SOLVER.AMP.ENABLED = True

# --- Output ---
now = datetime.datetime.now()
cfg.OUTPUT_DIR = f'./logs/{now.strftime("%Y%m%d_%H%M%S")}'
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

# Dump the config file in the output directory.
# cfg.dump() returns a YAML string; calling yaml.dump() on the CfgNode
# directly would serialise the Python object rather than a plain mapping.
# The `with` block closes the file, so no explicit close() is needed.
with open(os.path.join(cfg.OUTPUT_DIR, "config.yaml"), "w") as config:
    config.write(cfg.dump())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment