source domain training without DA
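The gist contains two files: the training entry point directly below, and the modified training loop (the do_train engine) after it. This is the source-only variant: arguments["use_feature_layers"] stays empty and no domain discriminators are built, so only the detector losses are optimized. Assuming the usual launcher for this codebase (the script path here is illustrative), a multi-GPU run is started with python -m torch.distributed.launch --nproc_per_node=N train_net.py --config-file <config>, which sets the WORLD_SIZE environment variable and the --local_rank flag the script reads.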
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
r"""
Basic training script for PyTorch
"""

# Set up custom environment before nearly anything else is imported
# NOTE: this should be the first import (do not reorder)
from fcos_core.utils.env import setup_environment  # noqa F401 isort:skip

import argparse
import os

import torch
from fcos_core.config import cfg
from fcos_core.data import make_data_loader
from fcos_core.solver import make_lr_scheduler
from fcos_core.solver import make_optimizer
from fcos_core.engine.inference import inference
from fcos_core.engine.trainer import do_train
from fcos_core.modeling.detector import build_detection_model
from fcos_core.modeling.backbone import build_backbone
from fcos_core.modeling.rpn.rpn import build_rpn
from fcos_core.utils.checkpoint import DetectronCheckpointer
from fcos_core.utils.collect_env import collect_env_info
from fcos_core.utils.comm import synchronize, get_rank, is_pytorch_1_1_0_or_later
from fcos_core.utils.imports import import_file
from fcos_core.utils.logger import setup_logger
from fcos_core.utils.miscellaneous import mkdir
def train(cfg, local_rank, distributed):
    # model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    # model.to(device)

    # The detector is kept as a dict of separately optimized components
    # rather than a single module.
    backbone = build_backbone(cfg).to(device)
    fcos = build_rpn(cfg, backbone.out_channels).to(device)
    model = {"backbone": backbone, "fcos": fcos}

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        # convert_sync_batchnorm expects an nn.Module, so convert each
        # component of the dict individually
        model = {
            k: torch.nn.SyncBatchNorm.convert_sync_batchnorm(m)
            for k, m in model.items()
        }

    # optimizer = make_optimizer(cfg, model)
    # scheduler = make_lr_scheduler(cfg, optimizer)
    optimizer = {}
    optimizer["backbone"] = make_optimizer(cfg, model["backbone"], name="backbone")
    optimizer["fcos"] = make_optimizer(cfg, model["fcos"], name="fcos")

    scheduler = {}
    scheduler["backbone"] = make_lr_scheduler(cfg, optimizer["backbone"], name="backbone")
    scheduler["fcos"] = make_lr_scheduler(cfg, optimizer["fcos"], name="fcos")

    if distributed:
        for k in model:
            model[k] = torch.nn.parallel.DistributedDataParallel(
                model[k], device_ids=[local_rank], output_device=local_rank,
                # this should be removed if we update BatchNorm stats
                broadcast_buffers=False,
            )

    arguments = {}
    arguments["iteration"] = 0
    arguments["use_dis_global"] = cfg.MODEL.ADV.USE_DIS_GLOBAL
    arguments["use_dis_ca"] = cfg.MODEL.ADV.USE_DIS_CENTER_AWARE
    arguments["ga_dis_lambda"] = cfg.MODEL.ADV.GA_DIS_LAMBDA
    arguments["ca_dis_lambda"] = cfg.MODEL.ADV.CA_DIS_LAMBDA
    # empty on purpose: no feature layer is fed to a domain discriminator,
    # which is what makes this the source-only (no DA) variant
    arguments["use_feature_layers"] = []

    output_dir = cfg.OUTPUT_DIR
    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    return model
def run_test(cfg, model, distributed):
    if distributed:
        # each component was wrapped in DistributedDataParallel; unwrap it
        model = {k: m.module for k, m in model.items()}
    torch.cuda.empty_cache()  # TODO check if it helps
    iou_types = ("bbox",)
    if cfg.MODEL.MASK_ON:
        iou_types = iou_types + ("segm",)
    if cfg.MODEL.KEYPOINT_ON:
        iou_types = iou_types + ("keypoints",)
    output_folders = [None] * len(cfg.DATASETS.TEST)
    dataset_names = cfg.DATASETS.TEST
    if cfg.OUTPUT_DIR:
        for idx, dataset_name in enumerate(dataset_names):
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name)
            mkdir(output_folder)
            output_folders[idx] = output_folder
    data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=distributed)
    for output_folder, dataset_name, data_loader_val in zip(
            output_folders, dataset_names, data_loaders_val):
        inference(
            model,
            data_loader_val,
            dataset_name=dataset_name,
            iou_types=iou_types,
            box_only=False if cfg.MODEL.FCOS_ON or cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY,
            device=cfg.MODEL.DEVICE,
            expected_results=cfg.TEST.EXPECTED_RESULTS,
            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
            output_folder=output_folder,
        )
        synchronize()
def main():
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "--skip-test",
        dest="skip_test",
        help="Do not test the final model",
        action="store_true",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()

    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    args.distributed = num_gpus > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://"
        )
        synchronize()

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        mkdir(output_dir)

    logger = setup_logger("fcos_core", output_dir, get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(args)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    logger.info("Loaded configuration file {}".format(args.config_file))
    with open(args.config_file, "r") as cf:
        config_str = "\n" + cf.read()
        logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    model = train(cfg, args.local_rank, args.distributed)

    if not args.skip_test:
        run_test(cfg, model, args.distributed)


if __name__ == "__main__":
    main()
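The notable departure from the stock FCOS script above is that model, optimizer, and scheduler are plain dicts keyed by component ("backbone", "fcos"), so every train()/zero_grad()/step() call site becomes a loop over the dict and the checkpointer handles all components together. A minimal, self-contained sketch of the same pattern, with hypothetical stand-in modules and only standard PyTorch APIs:

    import torch

    # hypothetical stand-ins for {"backbone": ..., "fcos": ...}
    model = {"backbone": torch.nn.Linear(8, 4), "head": torch.nn.Linear(4, 2)}
    optimizer = {k: torch.optim.SGD(m.parameters(), lr=0.01) for k, m in model.items()}
    scheduler = {k: torch.optim.lr_scheduler.StepLR(o, step_size=10) for k, o in optimizer.items()}

    for m in model.values():
        m.train()
    x, y = torch.randn(16, 8), torch.randn(16, 2)
    loss = torch.nn.functional.mse_loss(model["head"](model["backbone"](x)), y)
    for o in optimizer.values():
        o.zero_grad()
    loss.backward()
    for o in optimizer.values():  # each component steps with its own settings
        o.step()
    for s in scheduler.values():
        s.step()

This keeps the per-component optimizer settings independent, which is presumably what the name= argument to make_optimizer / make_lr_scheduler selects in the real script.

The second file of the gist is the training loop itself, i.e. the do_train engine imported above: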
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import datetime
import logging
import time

import torch
import torch.distributed as dist

from fcos_core.utils.comm import get_world_size, is_pytorch_1_1_0_or_later
from fcos_core.utils.metric_logger import MetricLogger
from fcos_core.structures.image_list import to_image_list
def forward_detector(model, images, targets=None, return_maps=False):
    # FPN levels produced by the backbone, in output order
    map_layer_to_index = {"P3": 0, "P4": 1, "P5": 2, "P6": 3, "P7": 4}
    feature_layers = map_layer_to_index.keys()

    model_backbone = model["backbone"]
    model_fcos = model["fcos"]

    images = to_image_list(images)
    features = model_backbone(images.tensors)
    f = {
        layer: features[map_layer_to_index[layer]]
        for layer in feature_layers
    }
    losses = {}

    if model_fcos.training and targets is None:
        # train G on target domain (no labels available)
        proposals, proposal_losses, score_maps = model_fcos(
            images, features, targets=None, return_maps=return_maps)
        # without targets the head returns only a dummy "zero" loss entry
        assert len(proposal_losses) == 1 and proposal_losses["zero"] == 0
    else:
        # train G on source domain / inference
        proposals, proposal_losses, score_maps = model_fcos(
            images, features, targets=targets, return_maps=return_maps)

    if model_fcos.training:
        # training: regroup the score maps by feature layer
        m = {
            layer: {
                map_type: score_maps[map_type][map_layer_to_index[layer]]
                for map_type in score_maps
            }
            for layer in feature_layers
        }
        losses.update(proposal_losses)
        return losses, f, m
    else:
        # inference
        result = proposals
        return result
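# Shape note (illustrative): in training mode forward_detector returns
# (losses, f, m), where f maps "P3".."P7" to the matching FPN feature tensor
# and m maps each of those levels to a dict of score maps keyed by map type,
# so per-level domain discriminators can index both by layer name. In eval
# mode only the detections (proposals) are returned.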
def reduce_loss_dict(loss_dict):
    """
    Reduce the loss dictionary from all processes so that the process with
    rank 0 has the averaged results. Returns a dict with the same fields as
    loss_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return loss_dict
    with torch.no_grad():
        loss_names = []
        all_losses = []
        for k in sorted(loss_dict.keys()):
            loss_names.append(k)
            all_losses.append(loss_dict[k])
        all_losses = torch.stack(all_losses, dim=0)
        dist.reduce(all_losses, dst=0)
        if dist.get_rank() == 0:
            # only the main process accumulates, so only divide by
            # world_size in this case
            all_losses /= world_size
        reduced_losses = {k: v for k, v in zip(loss_names, all_losses)}
    return reduced_losses
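# Worked example (illustrative numbers): with world_size == 2 and per-rank
# loss dicts {"loss_cls": 0.5} and {"loss_cls": 0.7}, dist.reduce sums the
# stacked tensor into rank 0, which divides by 2 and gets {"loss_cls": 0.6};
# the buffers on the other ranks are not meaningful afterwards, but only
# rank 0 does the logging.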
def do_train(
    model,
    data_loader,
    optimizer,
    scheduler,
    checkpointer,
    device,
    checkpoint_period,
    arguments,
):
    USE_DIS_GLOBAL = arguments["use_dis_global"]
    USE_DIS_CENTER_AWARE = arguments["use_dis_ca"]
    used_feature_layers = arguments["use_feature_layers"]

    # dataloader
    if USE_DIS_GLOBAL:
        data_loader_source = data_loader["source"]
        data_loader_target = data_loader["target"]
    else:
        data_loader_source = data_loader

    # domain labels the discriminators are trained against
    source_label = 0.0
    target_label = 1.0

    # dis_lambda
    if USE_DIS_GLOBAL:
        ga_dis_lambda = arguments["ga_dis_lambda"]
    if USE_DIS_CENTER_AWARE:
        ca_dis_lambda = arguments["ca_dis_lambda"]

    # Start training
    logger = logging.getLogger("fcos_core.trainer")
    logger.info("Start training")

    # model.train()
    for k in model:
        model[k].train()

    meters = MetricLogger(delimiter="  ")
    if USE_DIS_GLOBAL:
        assert len(data_loader_source) == len(data_loader_target)
        max_iter = max(len(data_loader_source), len(data_loader_target))
    else:
        max_iter = len(data_loader_source)
        # reuse the source loader so the zip below still works;
        # the target batch is simply ignored in this case
        data_loader_target = data_loader_source
    start_iter = arguments["iteration"]
    start_training_time = time.time()
    end = time.time()

    pytorch_1_1_0_or_later = is_pytorch_1_1_0_or_later()

    for iteration, ((images_s, targets_s, _), (images_t, _, _)) \
            in enumerate(zip(data_loader_source, data_loader_target), start_iter):
        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        # in pytorch >= 1.1.0, scheduler.step() should be run after optimizer.step()
        if not pytorch_1_1_0_or_later:
            # scheduler.step()
            for k in scheduler:
                scheduler[k].step()

        images_s = images_s.to(device)
        targets_s = [target_s.to(device) for target_s in targets_s]
        if USE_DIS_GLOBAL:
            images_t = images_t.to(device)
            # targets_t = [target_t.to(device) for target_t in targets_t]

        # optimizer.zero_grad()
        for k in optimizer:
            optimizer[k].zero_grad()

        ##########################################################################
        #################### (1): train G with source domain ####################
        ##########################################################################

        loss_dict, features_s, score_maps_s = forward_detector(
            model, images_s, targets=targets_s, return_maps=True)

        # rename the losses to indicate the domain
        loss_dict = {k + "_gs": loss_dict[k] for k in loss_dict}

        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss_gs=losses_reduced, **loss_dict_reduced)

        # retain the graph: the discriminator losses below also backprop
        # through features_s
        losses.backward(retain_graph=True)
        del loss_dict, losses

        ##########################################################################
        #################### (2): train D with source domain ####################
        ##########################################################################

        if USE_DIS_GLOBAL:
            loss_dict = {}
            for layer in used_feature_layers:
                # detach the score maps
                for map_type in score_maps_s[layer]:
                    score_maps_s[layer][map_type] = score_maps_s[layer][map_type].detach()
                if USE_DIS_GLOBAL:
                    loss_dict["loss_adv_%s_ds" % layer] = \
                        ga_dis_lambda * model["dis_%s" % layer](features_s[layer], source_label, domain='source')
                if USE_DIS_CENTER_AWARE:
                    loss_dict["loss_adv_%s_CA_ds" % layer] = \
                        ca_dis_lambda * model["dis_%s_CA" % layer](features_s[layer], source_label, score_maps_s[layer], domain='source')

            losses = sum(loss for loss in loss_dict.values())

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = reduce_loss_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            meters.update(loss_ds=losses_reduced, **loss_dict_reduced)

            losses.backward()
            del loss_dict, losses

        ##########################################################################
        #################### (3): train D with target domain ####################
        ##########################################################################

        if USE_DIS_GLOBAL:
            loss_dict, features_t, score_maps_t = forward_detector(
                model, images_t, return_maps=True)
            # loss_dict should contain only the dummy "zero" entry here
            assert len(loss_dict) == 1 and loss_dict["zero"] == 0
            # loss_dict["loss_adv_Pn"] = model_dis_Pn(features_t["Pn"], target_label, domain='target')
            for layer in used_feature_layers:
                # detach the score maps
                for map_type in score_maps_t[layer]:
                    score_maps_t[layer][map_type] = score_maps_t[layer][map_type].detach()
                if USE_DIS_GLOBAL:
                    loss_dict["loss_adv_%s_dt" % layer] = \
                        ga_dis_lambda * model["dis_%s" % layer](features_t[layer], target_label, domain='target')
                if USE_DIS_CENTER_AWARE:
                    loss_dict["loss_adv_%s_CA_dt" % layer] = \
                        ca_dis_lambda * model["dis_%s_CA" % layer](features_t[layer], target_label, score_maps_t[layer], domain='target')

            losses = sum(loss for loss in loss_dict.values())

            # del "zero" (useless after backward)
            del loss_dict['zero']

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = reduce_loss_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            meters.update(loss_dt=losses_reduced, **loss_dict_reduced)

            # save the GRL gradients for optional logging
            grad_list = []
            for layer in used_feature_layers:
                def save_grl_grad(grad):
                    grad_list.append(grad)
                features_t[layer].register_hook(save_grl_grad)
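            # the hooks fire during losses.backward() below, each capturing the
            # gradient flowing back into its (gradient-reversed) target feature
            # map for the optional logging block underneath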
            losses.backward()

            # Uncomment to log the GRL gradients
            grl_grad = {}
            grl_grad_log = {}
            # grl_grad = {
            #     layer: grad_list[i]
            #     for i, layer in enumerate(used_feature_layers)
            # }
            # for layer in used_feature_layers:
            #     saved_grad = grl_grad[layer]
            #     grl_grad_log["grl_%s_abs_mean" % layer] = torch.mean(
            #         torch.abs(saved_grad)) * 10e4
            #     grl_grad_log["grl_%s_mean" % layer] = torch.mean(saved_grad) * 10e6
            #     grl_grad_log["grl_%s_std" % layer] = torch.std(saved_grad) * 10e6
            #     grl_grad_log["grl_%s_max" % layer] = torch.max(saved_grad) * 10e6
            #     grl_grad_log["grl_%s_min" % layer] = torch.min(saved_grad) * 10e6
            # meters.update(**grl_grad_log)

            del loss_dict, losses, grad_list, grl_grad, grl_grad_log

        ##########################################################################
        ##########################################################################
        ##########################################################################

        # optimizer.step()
        for k in optimizer:
            optimizer[k].step()

        if pytorch_1_1_0_or_later:
            # scheduler.step()
            for k in scheduler:
                scheduler[k].step()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if USE_DIS_GLOBAL:
            # sample any one of the used feature layers for lr logging
            sample_layer = used_feature_layers[0]
            if USE_DIS_GLOBAL:
                sample_optimizer = optimizer["dis_%s" % sample_layer]
            if USE_DIS_CENTER_AWARE:
                sample_optimizer = optimizer["dis_%s_CA" % sample_layer]

        if iteration % 20 == 0 or iteration == max_iter:
            if USE_DIS_GLOBAL:
                logger.info(
                    meters.delimiter.join([
                        "eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr_backbone: {lr_backbone:.6f}",
                        "lr_fcos: {lr_fcos:.6f}",
                        "lr_dis: {lr_dis:.6f}",
                        "max mem: {memory:.0f}",
                    ]).format(
                        eta=eta_string,
                        iter=iteration,
                        meters=str(meters),
                        lr_backbone=optimizer["backbone"].param_groups[0]["lr"],
                        lr_fcos=optimizer["fcos"].param_groups[0]["lr"],
                        lr_dis=sample_optimizer.param_groups[0]["lr"],
                        memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                    ))
            else:
                logger.info(
                    meters.delimiter.join([
                        "eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr_backbone: {lr_backbone:.6f}",
                        "max mem: {memory:.0f}",
                    ]).format(
                        eta=eta_string,
                        iter=iteration,
                        meters=str(meters),
                        lr_backbone=optimizer["backbone"].param_groups[0]["lr"],
                        memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                    ))

        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        # if iteration > 5000 and iteration <= 6000 and iteration % 50 == 0:
        #     checkpointer.save("model_{:07d}".format(iteration), **arguments)
        if iteration == max_iter:
            checkpointer.save("model_final", **arguments)

    # End of training
    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / (max_iter)))