Code to ensemble multiple models
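Usage (inferred from the __main__ block below): the first command-line argument is an args file in which each non-comment line holds the detectron2 arguments for one model; any further command-line arguments are appended to every line. The weight of the first model in the geometric-mean fusion is read from the EWEIGHT environment variable, and the current fusion assumes exactly two models.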
try:
    # ignore ShapelyDeprecationWarning from fvcore
    import warnings

    from shapely.errors import ShapelyDeprecationWarning

    warnings.filterwarnings("ignore", category=ShapelyDeprecationWarning)
except ImportError:
    pass
import copy
import itertools
import logging
import os
from collections import OrderedDict, defaultdict
from typing import Any, Dict, List, Set

import detectron2.utils.comm as comm
import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import (
    DefaultTrainer,
    default_argument_parser,
    default_setup,
    launch,
)
from detectron2.evaluation import (
    CityscapesInstanceEvaluator,
    CityscapesSemSegEvaluator,
    COCOEvaluator,
    COCOPanopticEvaluator,
    DatasetEvaluators,
    LVISEvaluator,
    SemSegEvaluator,
    verify_results,
)
from torch.nn.parallel import DistributedDataParallel
from detectron2.projects.deeplab import add_deeplab_config, build_lr_scheduler
from detectron2.solver.build import maybe_add_gradient_clipping
from detectron2.utils.logger import setup_logger
from tabulate import tabulate
from termcolor import colored
from san.models import layers
from san import (
    COCOInstanceNewBaselineDatasetMapper,
    COCOPanopticNewBaselineDatasetMapper,
    InstanceSegEvaluator,
    MaskFormerInstanceDatasetMapper,
    MaskFormerPanopticDatasetMapper,
    MaskFormerSemanticDatasetMapper,
    SemanticSegmentorWithTTA,
    add_san_config,
)
from san.data import build_detection_test_loader, build_detection_train_loader
from san.utils import WandbWriter, setup_wandb
from san.utils.hooks import ModelAfterStepHook
from torch import nn
class ModelEnsemble(nn.Module):
    """
    Wrapper that ensembles the semantic-segmentation outputs of several models.

    Its :meth:`__call__` method has the same interface as
    :meth:`SemanticSegmentor.forward`.
    """

    def __init__(self, models):
        """
        Args:
            models (list[nn.Module]): the models to ensemble. Models wrapped in
                DistributedDataParallel are unwrapped to their underlying module.
        """
        super().__init__()
        self.models = nn.ModuleList()
        for model in models:
            if isinstance(model, DistributedDataParallel):
                model = model.module
            self.models.append(model)
    def __call__(self, batched_inputs):
        """
        Same input/output format as :meth:`SemanticSegmentor.forward`.
        """
        results = []
        for model in self.models:
            results.append(model(batched_inputs))
        final_results = []
        for predictions in zip(*results):
            sem_segs = [pred["sem_seg"].cpu() for pred in predictions]
            # Models may predict different numbers of classes; only the shared
            # leading classes are ensembled.
            min_cls = min(sem_seg.shape[0] for sem_seg in sem_segs)
            assert len(sem_segs) == 2
            # Weighted geometric mean of the two score maps; the weight of the
            # first model is read from the EWEIGHT environment variable.
            w = float(os.environ["EWEIGHT"])
            sem_seg = torch.pow(sem_segs[0][:min_cls], w) * torch.pow(
                sem_segs[1][:min_cls], 1 - w
            )
            # sem_seg = torch.stack(sem_segs, dim=0).mean(dim=0)
            torch.cuda.empty_cache()
            final_results.append({"sem_seg": sem_seg})
        return final_results
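# The fusion above is a weighted geometric mean: for per-pixel class scores
# p1 and p2 it computes p1**w * p2**(1 - w), which with w = 0.5 reduces to
# sqrt(p1 * p2). A minimal sketch (tensor shapes are illustrative, not taken
# from the script):
#
#   p1 = torch.rand(150, 64, 64)  # [num_classes, H, W] scores of model 1
#   p2 = torch.rand(150, 64, 64)  # scores of model 2
#   fused = torch.pow(p1, 0.5) * torch.pow(p2, 0.5)
#
# The result is not renormalized; the per-pixel argmax over classes, which is
# what the semantic-segmentation evaluators consume, is unaffected by that.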
class Trainer(DefaultTrainer):
    """
    Extension of the Trainer class adapted to MaskFormer.
    """

    def build_hooks(self):
        rets = super().build_hooks()
        rets.append(ModelAfterStepHook(self.model))
        return rets

    def build_writers(self):
        """
        Build a list of writers to be used. By default it contains
        writers that write metrics to the screen,
        a json file, and a tensorboard event file respectively.
        If you'd like a different list of writers, you can overwrite it in
        your trainer.

        Returns:
            list[EventWriter]: a list of :class:`EventWriter` objects.

        It is now implemented by:
        ::
            return [
                CommonMetricPrinter(self.max_iter),
                JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
                TensorboardXWriter(self.cfg.OUTPUT_DIR),
            ]
        """
        writers = super().build_writers()
        # Replace the default TensorboardXWriter (the last writer) with W&B.
        writers[-1] = WandbWriter()
        return writers
    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """
        Create evaluator(s) for a given dataset.
        This uses the special metadata "evaluator_type" associated with each
        builtin dataset. For your own dataset, you can simply create an
        evaluator manually in your script and do not have to worry about the
        hacky if-else logic here.
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        evaluator_list = []
        evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
        # semantic segmentation
        if evaluator_type in ["sem_seg", "ade20k_panoptic_seg"]:
            evaluator_list.append(
                SemSegEvaluator(
                    dataset_name,
                    distributed=True,
                    output_dir=output_folder,
                )
            )
        # instance segmentation
        if evaluator_type == "coco":
            evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder))
        # panoptic segmentation
        if evaluator_type in [
            "coco_panoptic_seg",
            "ade20k_panoptic_seg",
            "cityscapes_panoptic_seg",
            "mapillary_vistas_panoptic_seg",
        ]:
            if cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON:
                evaluator_list.append(
                    COCOPanopticEvaluator(dataset_name, output_folder)
                )
        # COCO
        if (
            evaluator_type == "coco_panoptic_seg"
            and cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON
        ):
            evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder))
        if (
            evaluator_type == "coco_panoptic_seg"
            and cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON
        ):
            evaluator_list.append(
                SemSegEvaluator(
                    dataset_name, distributed=True, output_dir=output_folder
                )
            )
        # Mapillary Vistas
        if (
            evaluator_type == "mapillary_vistas_panoptic_seg"
            and cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON
        ):
            evaluator_list.append(
                InstanceSegEvaluator(dataset_name, output_dir=output_folder)
            )
        if (
            evaluator_type == "mapillary_vistas_panoptic_seg"
            and cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON
        ):
            evaluator_list.append(
                SemSegEvaluator(
                    dataset_name, distributed=True, output_dir=output_folder
                )
            )
        # Cityscapes
        if evaluator_type == "cityscapes_instance":
            assert (
                torch.cuda.device_count() > comm.get_rank()
            ), "CityscapesEvaluator currently does not work with multiple machines."
            return CityscapesInstanceEvaluator(dataset_name)
        if evaluator_type == "cityscapes_sem_seg":
            assert (
                torch.cuda.device_count() > comm.get_rank()
            ), "CityscapesEvaluator currently does not work with multiple machines."
            return CityscapesSemSegEvaluator(dataset_name)
        if evaluator_type == "cityscapes_panoptic_seg":
            if cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON:
                assert (
                    torch.cuda.device_count() > comm.get_rank()
                ), "CityscapesEvaluator currently does not work with multiple machines."
                evaluator_list.append(CityscapesSemSegEvaluator(dataset_name))
            if cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON:
                assert (
                    torch.cuda.device_count() > comm.get_rank()
                ), "CityscapesEvaluator currently does not work with multiple machines."
                evaluator_list.append(CityscapesInstanceEvaluator(dataset_name))
        # ADE20K
        if (
            evaluator_type == "ade20k_panoptic_seg"
            and cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON
        ):
            evaluator_list.append(
                InstanceSegEvaluator(dataset_name, output_dir=output_folder)
            )
        # LVIS
        if evaluator_type == "lvis":
            return LVISEvaluator(dataset_name, output_dir=output_folder)
        if len(evaluator_list) == 0:
            raise NotImplementedError(
                "no Evaluator for the dataset {} with the type {}".format(
                    dataset_name, evaluator_type
                )
            )
        elif len(evaluator_list) == 1:
            return evaluator_list[0]
        return DatasetEvaluators(evaluator_list)
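    # For example, a dataset registered with evaluator_type
    # "ade20k_panoptic_seg" and PANOPTIC_ON/INSTANCE_ON enabled ends up with
    # SemSegEvaluator, COCOPanopticEvaluator, and InstanceSegEvaluator
    # combined in a single DatasetEvaluators wrapper.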
    @classmethod
    def build_train_loader(cls, cfg):
        # Semantic segmentation dataset mapper
        if cfg.INPUT.DATASET_MAPPER_NAME == "mask_former_semantic":
            mapper = MaskFormerSemanticDatasetMapper(cfg, True)
            return build_detection_train_loader(cfg, mapper=mapper)
        # Panoptic segmentation dataset mapper
        elif cfg.INPUT.DATASET_MAPPER_NAME == "mask_former_panoptic":
            mapper = MaskFormerPanopticDatasetMapper(cfg, True)
            return build_detection_train_loader(cfg, mapper=mapper)
        # Instance segmentation dataset mapper
        elif cfg.INPUT.DATASET_MAPPER_NAME == "mask_former_instance":
            mapper = MaskFormerInstanceDatasetMapper(cfg, True)
            return build_detection_train_loader(cfg, mapper=mapper)
        # COCO instance segmentation LSJ new baseline
        elif cfg.INPUT.DATASET_MAPPER_NAME == "coco_instance_lsj":
            mapper = COCOInstanceNewBaselineDatasetMapper(cfg, True)
            return build_detection_train_loader(cfg, mapper=mapper)
        # COCO panoptic segmentation LSJ new baseline
        elif cfg.INPUT.DATASET_MAPPER_NAME == "coco_panoptic_lsj":
            mapper = COCOPanopticNewBaselineDatasetMapper(cfg, True)
            return build_detection_train_loader(cfg, mapper=mapper)
        else:
            mapper = None
            return build_detection_train_loader(cfg, mapper=mapper)
    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        """
        Returns:
            iterable

        It now calls :func:`detectron2.data.build_detection_test_loader`.
        Overwrite it if you'd like a different data loader.
        """
        return build_detection_test_loader(cfg, dataset_name)

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer):
        """
        It now calls :func:`detectron2.solver.build_lr_scheduler`.
        Overwrite it if you'd like a different scheduler.
        """
        return build_lr_scheduler(cfg, optimizer)
    @classmethod
    def build_optimizer(cls, cfg, model):
        weight_decay_norm = cfg.SOLVER.WEIGHT_DECAY_NORM
        weight_decay_embed = cfg.SOLVER.WEIGHT_DECAY_EMBED

        defaults = {}
        defaults["lr"] = cfg.SOLVER.BASE_LR
        defaults["weight_decay"] = cfg.SOLVER.WEIGHT_DECAY

        norm_module_types = (
            torch.nn.BatchNorm1d,
            torch.nn.BatchNorm2d,
            torch.nn.BatchNorm3d,
            torch.nn.SyncBatchNorm,
            # NaiveSyncBatchNorm inherits from BatchNorm2d
            torch.nn.GroupNorm,
            torch.nn.InstanceNorm1d,
            torch.nn.InstanceNorm2d,
            torch.nn.InstanceNorm3d,
            torch.nn.LayerNorm,
            torch.nn.LocalResponseNorm,
            layers.LayerNorm,
        )

        params: List[Dict[str, Any]] = []
        memo: Set[torch.nn.parameter.Parameter] = set()
        for module_name, module in model.named_modules():
            for module_param_name, value in module.named_parameters(recurse=False):
                if not value.requires_grad:
                    continue
                # Avoid duplicating parameters
                if value in memo:
                    continue
                memo.add(value)

                hyperparams = copy.copy(defaults)
                hyperparams["param_name"] = ".".join([module_name, module_param_name])
                if "backbone" in module_name:
                    hyperparams["lr"] = (
                        hyperparams["lr"] * cfg.SOLVER.BACKBONE_MULTIPLIER
                    )
                if "masked_pooling_layer" in module_name:
                    hyperparams["lr"] = (
                        hyperparams["lr"] * cfg.SOLVER.CLIPHEAD_MULTIPLIER
                    )
                # Positional embeddings and query tokens are exempt from
                # weight decay.
                if (
                    "relative_position_bias_table" in module_param_name
                    or "absolute_pos_embed" in module_param_name
                    or "positional_embedding" in module_param_name
                    or "pos_embed" in module_param_name
                    or "query_token" in module_param_name
                ):
                    hyperparams["weight_decay"] = 0.0
                if isinstance(module, norm_module_types):
                    hyperparams["weight_decay"] = weight_decay_norm
                if isinstance(module, torch.nn.Embedding):
                    hyperparams["weight_decay"] = weight_decay_embed
                params.append({"params": [value], **hyperparams})
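        # "params" now holds one optimizer param group per tensor, so lr and
        # weight decay can differ per parameter: e.g. with BASE_LR = 1e-4 and
        # BACKBONE_MULTIPLIER = 0.1 (illustrative values, not taken from any
        # particular config), a backbone weight trains at lr 1e-5 while the
        # rest of the model trains at 1e-4.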
        def maybe_add_full_model_gradient_clipping(optim):
            # detectron2 doesn't have full-model gradient clipping now
            clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
            enable = (
                cfg.SOLVER.CLIP_GRADIENTS.ENABLED
                and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
                and clip_norm_val > 0.0
            )

            class FullModelGradientClippingOptimizer(optim):
                def step(self, closure=None):
                    # Clip the gradient norm over all parameters jointly
                    # before the underlying optimizer updates them.
                    all_params = itertools.chain(
                        *[x["params"] for x in self.param_groups]
                    )
                    torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
                    super().step(closure=closure)

            return FullModelGradientClippingOptimizer if enable else optim

        optimizer_type = cfg.SOLVER.OPTIMIZER
        if optimizer_type == "SGD":
            optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
                params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM
            )
        elif optimizer_type == "ADAMW":
            optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
                params, cfg.SOLVER.BASE_LR
            )
        else:
            raise NotImplementedError(f"no optimizer type {optimizer_type}")
        if cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE != "full_model":
            optimizer = maybe_add_gradient_clipping(cfg, optimizer)
        # display the lr and wd of each param group in a table
        optim_info = defaultdict(list)
        total_params_size = 0
        for group in optimizer.param_groups:
            optim_info["Param Name"].append(group["param_name"])
            optim_info["Param Shape"].append(
                "X".join([str(x) for x in list(group["params"][0].shape)])
            )
            total_params_size += group["params"][0].numel()
            optim_info["Lr"].append(group["lr"])
            optim_info["Wd"].append(group["weight_decay"])
        # Append the total parameter count as the last table row
        optim_info["Param Name"].append("Total")
        optim_info["Param Shape"].append(
            "{:.2f}M".format(total_params_size / 1024 / 1024)
        )
        optim_info["Lr"].append("-")
        optim_info["Wd"].append("-")
        table = tabulate(
            list(zip(*optim_info.values())),
            headers=optim_info.keys(),
            tablefmt="grid",
            floatfmt=".2e",
            stralign="center",
            numalign="center",
        )
        logger = logging.getLogger("e2e")
        logger.info("Optimizer Info:\n{}\n".format(table))
        return optimizer
    @classmethod
    def test_with_TTA(cls, cfg, model):
        logger = logging.getLogger("detectron2.trainer")
        # At the end of training, run an evaluation with TTA.
        logger.info("Running inference with test-time augmentation ...")
        model = SemanticSegmentorWithTTA(cfg, model)
        evaluators = [
            cls.build_evaluator(
                cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
            )
            for name in cfg.DATASETS.TEST
        ]
        res = cls.test(cfg, model, evaluators)
        res = OrderedDict({k + "_TTA": v for k, v in res.items()})
        return res
def setup(args, logger=False):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg()
    # for poly lr schedule
    add_deeplab_config(cfg)
    add_san_config(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)
    # Set up the "e2e" logger (only requested for the first model)
    if logger:
        setup_logger(
            output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="e2e"
        )
    return cfg
def main(args_list):
    models = []
    for args in args_list:
        cfg = setup(args, logger=len(models) == 0)
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        models.append(model)
    model = ModelEnsemble(models)
    # Evaluate the ensemble with the config of the last model in the list.
    res = Trainer.test(cfg, model)
    return res
if __name__ == "__main__":
    import sys

    args_file = sys.argv[1]
    if len(sys.argv) > 2:
        extra_opt = sys.argv[2:]
    else:
        extra_opt = []
    with open(args_file, "r") as f:
        args_list = f.readlines()
    # One model per non-comment line; strip trailing newlines and skip blanks.
    args_list = [
        ["--eval-only"] + x.strip().split(" ") + extra_opt
        for x in args_list
        if x.strip() and not x.strip().startswith("#")
    ]
    # extend args: a "{run}" placeholder expands into five runs (run ids 0-4)
    if "{run}" in " ".join([" ".join(args) for args in args_list]):
        args_list_group = []
        for run_id in range(5):
            args_list_group.append(
                [
                    [x.replace("{run}", str(run_id)) for x in args]
                    for args in args_list
                ]
            )
    else:
        args_list_group = [args_list]
    print("Will run {} experiments".format(len(args_list_group)))
    for args_list in args_list_group:
        print("=====================================")
        print(args_list)
        print("=====================================")
    for _args_list in args_list_group:
        args_list = [default_argument_parser().parse_args(args) for args in _args_list]
        print("Command Line Args:", args_list)
        args = args_list[0]
        launch(
            main,
            8,  # number of GPUs per machine is hard-coded to 8
            num_machines=args.num_machines,
            machine_rank=args.machine_rank,
            dist_url=args.dist_url,
            args=(args_list,),
        )
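# Example args file (paths and config names are hypothetical):
#
#   --config-file configs/san_vit_b_16.yaml MODEL.WEIGHTS out/model_a.pth
#   --config-file configs/san_vit_b_16.yaml MODEL.WEIGHTS out/model_b.pth
#
# Each non-comment line configures one model of the ensemble, and the current
# ModelEnsemble asserts exactly two models. A possible invocation (the script
# name is assumed; trailing key-value pairs are merged into every config):
#
#   EWEIGHT=0.5 python ensemble.py args.txt OUTPUT_DIR output/ensemble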