Skip to content

Instantly share code, notes, and snippets.

@MendelXu
Last active August 18, 2023 02:28
Show Gist options
  • Save MendelXu/105c7b91ba4b59b75acd488f6304b50f to your computer and use it in GitHub Desktop.
Save MendelXu/105c7b91ba4b59b75acd488f6304b50f to your computer and use it in GitHub Desktop.
Code to ensemble multiple models
try:
    # Ignore ShapelyDeprecationWarning emitted through fvcore.
    import warnings

    from shapely.errors import ShapelyDeprecationWarning

    warnings.filterwarnings("ignore", category=ShapelyDeprecationWarning)
except (ImportError, OSError):
    # Best effort only. If shapely (or this warning class) is unavailable or its
    # native library fails to load, there is nothing to silence. A narrow except
    # (instead of a bare one) keeps KeyboardInterrupt/SystemExit and real bugs visible.
    pass
import copy
import itertools
import logging
import os
from collections import OrderedDict, defaultdict
from typing import Any, Dict, List, Set
import detectron2.utils.comm as comm
import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import (
DefaultTrainer,
default_argument_parser,
default_setup,
launch,
)
from detectron2.evaluation import (
CityscapesInstanceEvaluator,
CityscapesSemSegEvaluator,
COCOEvaluator,
COCOPanopticEvaluator,
DatasetEvaluators,
LVISEvaluator,
SemSegEvaluator,
verify_results,
)
from torch.nn.parallel import DistributedDataParallel
from detectron2.projects.deeplab import add_deeplab_config, build_lr_scheduler
from detectron2.solver.build import maybe_add_gradient_clipping
from detectron2.utils.logger import setup_logger
from tabulate import tabulate
from termcolor import colored
from san.models import layers
from san import (
COCOInstanceNewBaselineDatasetMapper,
COCOPanopticNewBaselineDatasetMapper,
InstanceSegEvaluator,
MaskFormerInstanceDatasetMapper,
MaskFormerPanopticDatasetMapper,
MaskFormerSemanticDatasetMapper,
SemanticSegmentorWithTTA,
add_san_config,
)
from san.data import build_detection_test_loader, build_detection_train_loader
from san.utils import WandbWriter, setup_wandb
from san.utils.hooks import ModelAfterStepHook
from torch import nn
class ModelEnsemble(nn.Module):
    """
    Ensemble several semantic-segmentation models by a weighted geometric mean.

    Every wrapped model is run on the same ``batched_inputs``. For each image:

    * the per-model ``"sem_seg"`` tensors are moved to CPU;
    * they are truncated along dim 0 to the smallest size among the models;
    * they are fused as ``prod_i sem_seg_i ** w_i``.

    Calling the module has the same input/output format as
    :meth:`SemanticSegmentor.forward`.
    """

    def __init__(self, models, weights=None):
        """
        Args:
            models (iterable[nn.Module]): models to ensemble. ``DistributedDataParallel``
                wrappers are unwrapped so the underlying module is called directly.
            weights (sequence[float] or None): one exponent per model. If None,
                weights are resolved on every call:
                - from the ``EWEIGHT`` environment variable as ``[w, 1 - w]`` when
                  it is set (exactly two models required);
                - otherwise uniformly as ``1 / len(models)``.

        Raises:
            ValueError: if no models are given, or ``weights`` has the wrong length.
        """
        super().__init__()
        self.models = nn.ModuleList()
        for model in models:
            if isinstance(model, DistributedDataParallel):
                model = model.module
            self.models.append(model)
        if len(self.models) == 0:
            raise ValueError("ModelEnsemble needs at least one model")
        if weights is not None:
            weights = [float(weight) for weight in weights]
            if len(weights) != len(self.models):
                raise ValueError(
                    "got {} weights for {} models".format(len(weights), len(self.models))
                )
        self.weights = weights

    def _resolve_weights(self):
        """Return the per-model exponents used to fuse predictions."""
        num_models = len(self.models)
        if self.weights is not None:
            return self.weights
        if "EWEIGHT" in os.environ:
            if num_models != 2:
                raise ValueError(
                    "EWEIGHT applies to exactly two models, got {}".format(num_models)
                )
            w = float(os.environ["EWEIGHT"])
            return [w, 1 - w]
        return [1.0 / num_models] * num_models

    def forward(self, batched_inputs):
        """
        Same input/output format as :meth:`SemanticSegmentor.forward`.

        Returns:
            list[dict]: one ``{"sem_seg": tensor}`` per input, holding the fused
            CPU tensor.
        """
        # Resolve once per call rather than once per image.
        weights = self._resolve_weights()
        results = [model(batched_inputs) for model in self.models]
        final_results = []
        for predictions in zip(*results):
            sem_segs = [pred["sem_seg"].cpu() for pred in predictions]
            # Models may predict different numbers of classes; keep the shared prefix.
            min_cls = min(sem_seg.shape[0] for sem_seg in sem_segs)
            sem_seg = None
            for seg, weight in zip(sem_segs, weights):
                term = torch.pow(seg[:min_cls], weight)
                sem_seg = term if sem_seg is None else sem_seg * term
            final_results.append({"sem_seg": sem_seg})
        # Drop references to the (possibly GPU) raw outputs before releasing the cache.
        del results
        torch.cuda.empty_cache()
        return final_results
class Trainer(DefaultTrainer):
    """
    Extension of the Trainer class adapted to MaskFormer.

    On top of detectron2's ``DefaultTrainer`` it:

    * adds a ``ModelAfterStepHook``;
    * replaces the last default writer with a ``WandbWriter``;
    * picks evaluators from the dataset's ``evaluator_type``;
    * picks training dataset mappers from ``cfg.INPUT.DATASET_MAPPER_NAME``;
    * uses the deeplab LR scheduler;
    * builds one optimizer param group per trainable parameter.
    """

    def build_hooks(self):
        """Return the default hooks plus a ``ModelAfterStepHook`` on ``self.model``."""
        rets = super().build_hooks()
        rets.append(ModelAfterStepHook(self.model))
        return rets

    def build_writers(self):
        """
        Build a list of writers to be used. By default it contains
        writers that write metrics to the screen,
        a json file, and a tensorboard event file respectively.
        If you'd like a different list of writers, you can overwrite it in
        your trainer.

        Returns:
            list[EventWriter]: a list of :class:`EventWriter` objects.

        The default implementation is:
        ::
            return [
                CommonMetricPrinter(self.max_iter),
                JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
                TensorboardXWriter(self.cfg.OUTPUT_DIR),
            ]

        Here the last writer of that list is replaced by a ``WandbWriter``.
        """
        writers = super().build_writers()
        writers[-1] = WandbWriter()
        return writers

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """
        Create evaluator(s) for a given dataset.
        This uses the special metadata "evaluator_type" associated with each
        builtin dataset. For your own dataset, you can simply create an
        evaluator manually in your script and do not have to worry about the
        hacky if-else logic here.

        Args:
            cfg (CfgNode): config; ``cfg.MODEL.MASK_FORMER.TEST.*`` flags are read
                only for panoptic/ADE20K evaluator types.
            dataset_name (str): registered dataset name.
            output_folder (str or None): defaults to ``<OUTPUT_DIR>/inference``.

        Returns:
            A single evaluator, or ``DatasetEvaluators`` when several apply.

        Raises:
            NotImplementedError: if no evaluator matches the dataset's type.
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        evaluator_list = []
        evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
        # semantic segmentation
        if evaluator_type in ["sem_seg", "ade20k_panoptic_seg"]:
            evaluator_list.append(
                SemSegEvaluator(
                    dataset_name,
                    distributed=True,
                    output_dir=output_folder,
                )
            )
        # instance segmentation
        if evaluator_type == "coco":
            evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder))
        # panoptic segmentation
        if evaluator_type in [
            "coco_panoptic_seg",
            "ade20k_panoptic_seg",
            "cityscapes_panoptic_seg",
            "mapillary_vistas_panoptic_seg",
        ]:
            if cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON:
                evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder))
        # COCO
        if (
            evaluator_type == "coco_panoptic_seg"
            and cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON
        ):
            evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder))
        if (
            evaluator_type == "coco_panoptic_seg"
            and cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON
        ):
            evaluator_list.append(
                SemSegEvaluator(dataset_name, distributed=True, output_dir=output_folder)
            )
        # Mapillary Vistas
        if (
            evaluator_type == "mapillary_vistas_panoptic_seg"
            and cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON
        ):
            evaluator_list.append(
                InstanceSegEvaluator(dataset_name, output_dir=output_folder)
            )
        if (
            evaluator_type == "mapillary_vistas_panoptic_seg"
            and cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON
        ):
            evaluator_list.append(
                SemSegEvaluator(dataset_name, distributed=True, output_dir=output_folder)
            )
        # Cityscapes
        if evaluator_type == "cityscapes_instance":
            assert (
                torch.cuda.device_count() > comm.get_rank()
            ), "CityscapesEvaluator currently do not work with multiple machines."
            return CityscapesInstanceEvaluator(dataset_name)
        if evaluator_type == "cityscapes_sem_seg":
            assert (
                torch.cuda.device_count() > comm.get_rank()
            ), "CityscapesEvaluator currently do not work with multiple machines."
            return CityscapesSemSegEvaluator(dataset_name)
        if evaluator_type == "cityscapes_panoptic_seg":
            if cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON:
                assert (
                    torch.cuda.device_count() > comm.get_rank()
                ), "CityscapesEvaluator currently do not work with multiple machines."
                evaluator_list.append(CityscapesSemSegEvaluator(dataset_name))
            if cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON:
                assert (
                    torch.cuda.device_count() > comm.get_rank()
                ), "CityscapesEvaluator currently do not work with multiple machines."
                evaluator_list.append(CityscapesInstanceEvaluator(dataset_name))
        # ADE20K
        if (
            evaluator_type == "ade20k_panoptic_seg"
            and cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON
        ):
            evaluator_list.append(
                InstanceSegEvaluator(dataset_name, output_dir=output_folder)
            )
        # LVIS
        if evaluator_type == "lvis":
            return LVISEvaluator(dataset_name, output_dir=output_folder)
        if len(evaluator_list) == 0:
            raise NotImplementedError(
                "no Evaluator for the dataset {} with the type {}".format(
                    dataset_name, evaluator_type
                )
            )
        elif len(evaluator_list) == 1:
            return evaluator_list[0]
        return DatasetEvaluators(evaluator_list)

    @classmethod
    def build_train_loader(cls, cfg):
        """
        Build the training loader.

        The dataset mapper is chosen by ``cfg.INPUT.DATASET_MAPPER_NAME``.
        Unknown names pass ``mapper=None`` to ``build_detection_train_loader``.
        """
        mapper_classes = {
            # Semantic segmentation dataset mapper
            "mask_former_semantic": MaskFormerSemanticDatasetMapper,
            # Panoptic segmentation dataset mapper
            "mask_former_panoptic": MaskFormerPanopticDatasetMapper,
            # Instance segmentation dataset mapper
            "mask_former_instance": MaskFormerInstanceDatasetMapper,
            # coco instance segmentation lsj new baseline
            "coco_instance_lsj": COCOInstanceNewBaselineDatasetMapper,
            # coco panoptic segmentation lsj new baseline
            "coco_panoptic_lsj": COCOPanopticNewBaselineDatasetMapper,
        }
        mapper_cls = mapper_classes.get(cfg.INPUT.DATASET_MAPPER_NAME)
        mapper = None if mapper_cls is None else mapper_cls(cfg, True)
        return build_detection_train_loader(cfg, mapper=mapper)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        """
        Returns:
            iterable

        It now calls :func:`san.data.build_detection_test_loader`.
        Overwrite it if you'd like a different data loader.
        """
        return build_detection_test_loader(cfg, dataset_name)

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer):
        """
        It now calls the deeplab project's ``build_lr_scheduler``.
        Overwrite it if you'd like a different scheduler.
        """
        return build_lr_scheduler(cfg, optimizer)

    @classmethod
    def build_optimizer(cls, cfg, model):
        """
        Build the optimizer with one param group per trainable parameter.

        Learning rate. Each group starts from ``cfg.SOLVER.BASE_LR``. It is
        scaled by ``BACKBONE_MULTIPLIER`` for modules whose name contains
        "backbone", and by ``CLIPHEAD_MULTIPLIER`` for modules containing
        "masked_pooling_layer".

        Weight decay. Each group starts from ``cfg.SOLVER.WEIGHT_DECAY``.
        - Positional-embedding / query-token parameters get zero weight decay.
        - Norm layers then get ``WEIGHT_DECAY_NORM``.
        - ``nn.Embedding`` layers get ``WEIGHT_DECAY_EMBED``.

        Gradient clipping is applied to the full model when
        ``CLIP_TYPE == "full_model"``; otherwise detectron2's
        ``maybe_add_gradient_clipping`` is used. A table of all groups is logged.

        Raises:
            NotImplementedError: if ``cfg.SOLVER.OPTIMIZER`` is not "SGD" or "ADAMW".
        """
        logger = logging.getLogger("e2e")
        weight_decay_norm = cfg.SOLVER.WEIGHT_DECAY_NORM
        weight_decay_embed = cfg.SOLVER.WEIGHT_DECAY_EMBED

        defaults = {}
        defaults["lr"] = cfg.SOLVER.BASE_LR
        defaults["weight_decay"] = cfg.SOLVER.WEIGHT_DECAY

        norm_module_types = (
            torch.nn.BatchNorm1d,
            torch.nn.BatchNorm2d,
            torch.nn.BatchNorm3d,
            torch.nn.SyncBatchNorm,
            # NaiveSyncBatchNorm inherits from BatchNorm2d
            torch.nn.GroupNorm,
            torch.nn.InstanceNorm1d,
            torch.nn.InstanceNorm2d,
            torch.nn.InstanceNorm3d,
            torch.nn.LayerNorm,
            torch.nn.LocalResponseNorm,
            layers.LayerNorm,
        )

        params: List[Dict[str, Any]] = []
        memo: Set[torch.nn.parameter.Parameter] = set()
        for module_name, module in model.named_modules():
            for module_param_name, value in module.named_parameters(recurse=False):
                if not value.requires_grad:
                    continue
                # Avoid duplicating parameters
                if value in memo:
                    continue
                memo.add(value)

                hyperparams = copy.copy(defaults)
                hyperparams["param_name"] = ".".join([module_name, module_param_name])
                if "backbone" in module_name:
                    hyperparams["lr"] = hyperparams["lr"] * cfg.SOLVER.BACKBONE_MULTIPLIER
                if "masked_pooling_layer" in module_name:
                    hyperparams["lr"] = hyperparams["lr"] * cfg.SOLVER.CLIPHEAD_MULTIPLIER
                if (
                    "relative_position_bias_table" in module_param_name
                    or "absolute_pos_embed" in module_param_name
                    or "positional_embedding" in module_param_name
                    or "pos_embed" in module_param_name
                    or "query_token" in module_param_name
                ):
                    # Previously a bare print() on every process; route it through the logger.
                    logger.debug("Disable weight decay for %s", module_param_name)
                    hyperparams["weight_decay"] = 0.0
                if isinstance(module, norm_module_types):
                    hyperparams["weight_decay"] = weight_decay_norm
                if isinstance(module, torch.nn.Embedding):
                    hyperparams["weight_decay"] = weight_decay_embed
                params.append({"params": [value], **hyperparams})

        def maybe_add_full_model_gradient_clipping(optim):
            # detectron2 doesn't have full model gradient clipping now
            clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
            enable = (
                cfg.SOLVER.CLIP_GRADIENTS.ENABLED
                and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
                and clip_norm_val > 0.0
            )

            class FullModelGradientClippingOptimizer(optim):
                def step(self, closure=None):
                    all_params = itertools.chain(
                        *[x["params"] for x in self.param_groups]
                    )
                    torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
                    super().step(closure=closure)

            return FullModelGradientClippingOptimizer if enable else optim

        optimizer_type = cfg.SOLVER.OPTIMIZER
        if optimizer_type == "SGD":
            optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
                params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM
            )
        elif optimizer_type == "ADAMW":
            optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
                params, cfg.SOLVER.BASE_LR
            )
        else:
            raise NotImplementedError(f"no optimizer type {optimizer_type}")
        if cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE != "full_model":
            optimizer = maybe_add_gradient_clipping(cfg, optimizer)

        # display the lr and wd of each param group in a table
        optim_info = defaultdict(list)
        total_params_size = 0
        for group in optimizer.param_groups:
            optim_info["Param Name"].append(group["param_name"])
            optim_info["Param Shape"].append(
                "X".join([str(x) for x in list(group["params"][0].shape)])
            )
            total_params_size += group["params"][0].numel()
            optim_info["Lr"].append(group["lr"])
            optim_info["Wd"].append(group["weight_decay"])
        # Total number of parameters, in millions ("M" = 1e6, not 1024 * 1024).
        optim_info["Param Name"].append("Total")
        optim_info["Param Shape"].append("{:.2f}M".format(total_params_size / 1e6))
        optim_info["Lr"].append("-")
        optim_info["Wd"].append("-")
        table = tabulate(
            list(zip(*optim_info.values())),
            headers=optim_info.keys(),
            tablefmt="grid",
            floatfmt=".2e",
            stralign="center",
            numalign="center",
        )
        logger.info("Optimizer Info:\n{}\n".format(table))
        return optimizer

    @classmethod
    def test_with_TTA(cls, cfg, model):
        """
        Evaluate ``model`` wrapped in ``SemanticSegmentorWithTTA`` on ``cfg.DATASETS.TEST``.

        Evaluator outputs go to ``<OUTPUT_DIR>/inference_TTA``, and every
        result key gets a ``_TTA`` suffix.
        """
        logger = logging.getLogger("detectron2.trainer")
        # In the end of training, run an evaluation with TTA.
        logger.info("Running inference with test-time augmentation ...")
        model = SemanticSegmentorWithTTA(cfg, model)
        evaluators = [
            cls.build_evaluator(
                cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
            )
            for name in cfg.DATASETS.TEST
        ]
        res = cls.test(cfg, model, evaluators)
        res = OrderedDict({k + "_TTA": v for k, v in res.items()})
        return res
def setup(args, logger=False):
    """
    Build, freeze and return the config described by ``args``.

    The deeplab (poly LR schedule) and SAN config extensions are registered
    first. The config file and command-line overrides are then merged, and
    detectron2's ``default_setup`` runs. When ``logger`` is truthy, the "e2e"
    logger is also configured to write under ``cfg.OUTPUT_DIR``.
    """
    cfg = get_cfg()
    # Register extra config keys (poly LR schedule, SAN options) before merging.
    for register_keys in (add_deeplab_config, add_san_config):
        register_keys(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)
    if not logger:
        return cfg
    # Setup logger for "mask_former" module
    setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="e2e")
    return cfg
def main(args_list):
    """
    Build one model per argument namespace, ensemble them, and evaluate.

    Args:
        args_list (list[argparse.Namespace]): one namespace per model, as parsed by
            ``default_argument_parser``. Each is used for ``setup`` and its
            ``resume`` flag.

    Returns:
        The evaluation results returned by ``Trainer.test``. Previously these
        were computed and then discarded.

    Raises:
        ValueError: if ``args_list`` is empty.
    """
    if not args_list:
        raise ValueError("main() needs at least one set of model arguments")
    models = []
    for args in args_list:
        # Only the first setup() call configures the "e2e" logger.
        cfg = setup(args, logger=not models)
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        models.append(model)
    model = ModelEnsemble(models)
    # NOTE: evaluation (datasets, evaluators, output dir) uses the LAST model's config.
    return Trainer.test(cfg, model)
def read_args_file(path, extra_opt=()):
    """
    Parse an ensemble args file into one argv list per model.

    Each line that is non-blank and does not start with ``#`` (after
    stripping) is split on whitespace. The tokens are prefixed with
    ``--eval-only`` and suffixed with ``extra_opt``. Splitting on arbitrary
    whitespace avoids keeping the line's trailing newline on its last token.

    Args:
        path (str): text file with one command line per model.
        extra_opt (iterable[str]): extra tokens appended to every command line.

    Returns:
        list[list[str]]: one argv list per model line.
    """
    extra_opt = list(extra_opt)
    with open(path, "r", encoding="utf-8") as f:
        lines = f.readlines()
    return [
        ["--eval-only"] + line.split() + extra_opt
        for line in lines
        if line.strip() and not line.strip().startswith("#")
    ]


def expand_run_groups(args_list, num_runs=5, placeholder="{run}"):
    """
    Expand a run-index placeholder into ``num_runs`` experiment groups.

    If any token contains ``placeholder``, return ``num_runs`` copies of
    ``args_list``. In copy ``i``, every occurrence of the placeholder is
    replaced by ``str(i)``. Otherwise return ``[args_list]``.
    """
    if not any(placeholder in token for args in args_list for token in args):
        return [args_list]
    return [
        [[token.replace(placeholder, str(run_id)) for token in args] for args in args_list]
        for run_id in range(num_runs)
    ]


if __name__ == "__main__":
    import sys

    args_file = sys.argv[1]
    args_list = read_args_file(args_file, sys.argv[2:])
    if not args_list:
        raise SystemExit("no model command lines found in {}".format(args_file))
    # extend args
    args_list_group = expand_run_groups(args_list)
    print("Will run {} experiments".format(len(args_list_group)))
    for group in args_list_group:
        print("=====================================")
        print(group)
        print("=====================================")
    for group in args_list_group:
        parsed_args = [default_argument_parser().parse_args(args) for args in group]
        print("Command Line Args:", parsed_args)
        args = parsed_args[0]
        launch(
            main,
            8,  # GPUs per machine: hard-coded, --num-gpus is not consulted
            num_machines=args.num_machines,
            machine_rank=args.machine_rank,
            dist_url=args.dist_url,
            args=(parsed_args,),
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment