zero-shot SIB for domain adaptation

import os
import time
import math
import itertools
import numpy as np
from PIL import Image
from tqdm import tqdm
from copy import deepcopy

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

# custom modules
from schedulers import get_scheduler
from optimizers import get_optimizer
from networks import get_network
from utils.metrics import AverageMeter
from utils.utils import to_device

# summary
from tensorboardX import SummaryWriter


def entropy_loss(x):
    # mean Shannon entropy of softmax(x); used as the unsupervised task loss
    # on target-domain batches, where class labels are unavailable
    return torch.sum(-F.softmax(x, 1) * F.log_softmax(x, 1), 1).mean()


class DNINet(nn.Module):
    '''
    Synthetic-gradient network (DNI): an MLP that maps classification logits to
    a predicted gradient of the same shape, consumed in synthetic_grad_update.
    '''
    def __init__(self, input_dims, dni_hidden_size=1024):
        super(DNINet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Linear(input_dims, dni_hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(dni_hidden_size)
        )
        self.layer2 = nn.Sequential(
            nn.Linear(dni_hidden_size, dni_hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(dni_hidden_size)
        )
        self.layer3 = nn.Linear(dni_hidden_size, input_dims)

        # self.apply(weight_init)
        for m in self.modules():
            if isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                nn.init.normal_(m.bias)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        return out


class SIBNet(nn.Module):
    def __init__(self, args):
        super(SIBNet, self).__init__()
        self.args = args
        self.n_steps = args.training.n_steps

        # meta-model: feature-net & synthetic-grad-net
        self.feat_net = get_network(args.network.arch)(output='feature')
        self.feat_dim = self.feat_net.feat_dim
        self.dni = DNINet(args.n_classes)

        # meta-model: classification
        self.theta = nn.Parameter(torch.FloatTensor(args.n_classes, self.feat_dim))
        nn.init.kaiming_uniform_(self.theta, a=math.sqrt(5))
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.theta)
        bound = 1 / math.sqrt(fan_in)
        self.bias_cls = nn.Parameter(torch.FloatTensor(args.n_classes))
        nn.init.uniform_(self.bias_cls, -bound, bound)

        self.up_dim = self.feat_dim
        self.up = nn.Parameter(torch.FloatTensor(self.up_dim, args.n_classes))
        nn.init.kaiming_normal_(self.up)

        # meta-model: auxiliary self-sup (e.g. rotation)
        self.rho = nn.Parameter(torch.FloatTensor(args.aux_classes + 1, self.feat_dim + self.up_dim))
        nn.init.kaiming_uniform_(self.rho, a=math.sqrt(5))
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.rho)
        bound = 1 / math.sqrt(fan_in)
        self.bias_aux = nn.Parameter(torch.FloatTensor(args.aux_classes + 1))
        nn.init.uniform_(self.bias_aux, -bound, bound)

        # loss
        self.class_loss_func = nn.CrossEntropyLoss()

    def multitask_predict(self, features, cls_weights, aux_weights):
        '''
        features:    n x nFeat
        cls_weights: nClass x nFeat
        aux_weights: nRot x (nFeat + upDim)
        returns class scores (n x nClass) and auxiliary scores (n x nRot)
        '''
        tensor_cls = F.linear(features, cls_weights, self.bias_cls)
        tensor_up = torch.einsum('nk,dk->nd', tensor_cls, self.up)  # n x upDim
        tensor_cat = torch.cat([features, tensor_up], dim=1)  # n x (nFeat + upDim)
        tensor_aux = F.linear(tensor_cat, aux_weights, self.bias_aux)
        return tensor_cls, tensor_aux

    def self_sup_update(self, aux_labels, theta, features, lr, n_steps=1):
        # adapt the classifier weights theta on the self-supervised (auxiliary)
        # loss of the current minibatch via a few gradient-descent steps
        theta = theta.clone()  # TODO: check this doesn't affect theta in meta-model
        for t in range(n_steps):
            cls_scores, aux_scores = self.multitask_predict(features, theta, self.rho.detach())
            aux_loss = self.class_loss_func(aux_scores, aux_labels)
            grad = torch.autograd.grad([aux_loss], [theta],
                                       create_graph=True, retain_graph=True,
                                       only_inputs=True)[0]  # nClass x nFeat
            # perform GD
            theta = theta - lr * grad
        return theta

    def synthetic_grad_update(self, theta, features, lr, n_steps=3):
        # adapt theta with the synthetic gradient: the DNI net predicts a gradient
        # w.r.t. the class logits, which is backpropagated to theta via grad_outputs
        theta = theta.clone()
        for t in range(n_steps):
            cls_scores, aux_scores = self.multitask_predict(features, theta, self.rho.detach())
            grad_logit = self.dni(cls_scores)  # n x nClass
            grad = torch.autograd.grad([cls_scores], [theta],
                                       grad_outputs=[grad_logit],
                                       create_graph=True, retain_graph=True,
                                       only_inputs=True)[0]  # nClass x nFeat
            # perform GD
            theta = theta - lr * grad
        return theta

    def forward(self, img, lr, aux_labels, cls_labels=None, sib_update=True):
        feat = self.feat_net(img)

        # generate classification weights for this minibatch
        theta = self.theta
        if self.args.training.with_self_sup_update and sib_update:
            theta = self.self_sup_update(aux_labels, theta,
                                         feat.detach(), lr, n_steps=1)
        if self.args.training.with_synthetic_grad_update and sib_update:
            theta = self.synthetic_grad_update(theta, feat.detach(), lr, n_steps=self.n_steps)

        cls_scores, aux_scores = self.multitask_predict(feat, theta, self.rho)
        aux_loss = self.class_loss_func(aux_scores, aux_labels)
        if cls_labels is None:
            task_loss = entropy_loss(cls_scores[aux_labels == 0])
        else:
            task_loss = self.class_loss_func(cls_scores, cls_labels)

        return aux_loss, task_loss, aux_scores, cls_scores
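

# Illustrative shape check (an addition, not part of the original gist): it reproduces
# the algebra of SIBNet.multitask_predict with plain tensors so the tensor shapes are
# easy to follow; feat_dim, n_classes and aux_classes below are arbitrary example values.
def _multitask_predict_shape_check(n=8, feat_dim=512, n_classes=31, aux_classes=3):
    features = torch.randn(n, feat_dim)
    theta = torch.randn(n_classes, feat_dim)          # classification weights
    up = torch.randn(feat_dim, n_classes)             # maps class scores back to feature space
    rho = torch.randn(aux_classes + 1, 2 * feat_dim)  # auxiliary (e.g. rotation) weights
    cls_scores = F.linear(features, theta)                  # n x n_classes
    tensor_up = torch.einsum('nk,dk->nd', cls_scores, up)   # n x feat_dim
    tensor_cat = torch.cat([features, tensor_up], dim=1)    # n x (2 * feat_dim)
    aux_scores = F.linear(tensor_cat, rho)                  # n x (aux_classes + 1)
    assert cls_scores.shape == (n, n_classes)
    assert aux_scores.shape == (n, aux_classes + 1)
    return cls_scores.shape, aux_scores.shape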


class SIBModel:
    def __init__(self, args, logger):
        self.args = args
        self.logger = logger
        self.writer = SummaryWriter(args.log_dir)
        cudnn.enabled = True
        self.task_type = args.get('task_type', 'DA')
        print('Task type: {}'.format(self.task_type))

        # set up model
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SIBNet(args)
        self.model = self.model.to(self.device)

        # optimizer & scheduler & loss
        if args.mode == 'train':
            # set up optimizer, lr scheduler and loss functions
            optimizer = get_optimizer(self.args.training.optimizer)
            optimizer_params = {k: v for k, v in self.args.training.optimizer.items() if k != "name"}
            self.optimizer = optimizer(self.model.parameters(), **optimizer_params)
            self.scheduler = get_scheduler(self.optimizer, self.args.training.lr_scheduler)
            self.start_iter = 0

            # resume
            if args.training.resume:
                self.load(args.model_dir + '/' + args.training.resume)
        elif args.mode == 'val':
            self.load(os.path.join(args.model_dir, args.validation.model))
        else:
            self.load(os.path.join(args.model_dir, args.testing.model))

        self.src_aux_weight = args.training.src_aux_weight
        self.tar_aux_weight = args.training.tar_aux_weight
        self.tar_entropy_weight = args.training.tar_entropy_weight

        # algorithmic params
        self.n_steps = args.training.n_steps
        self.warmup_epochs = self.args.training.warmup_epochs
        self.sib_lr = self.args.training.sib_lr
        self.sib_update = False
        self.logger.info('warmup epochs: {}'.format(self.warmup_epochs))

    def train(self, src_loader, tar_loader, val_loader, test_loader):
        num_batches = len(src_loader)
        print('Number of batches: %d' % num_batches)
        print_freq = max(num_batches // self.args.training.num_print_epoch, 1)
        i_iter = self.start_iter
        start_epoch = i_iter // num_batches
        num_epochs = self.args.training.num_epochs
        best_acc = 0
        best_acc_nonsib = 0
        for epoch in range(start_epoch, num_epochs):
            self.model.train()
            batch_time = AverageMeter()
            losses = AverageMeter()

            if epoch > self.warmup_epochs:
                self.sib_update = True
                self.logger.info('Epoch {:>2} : sib_update enabled'.format(epoch))

            for it, (src_batch, tar_batch) in enumerate(zip(src_loader, itertools.cycle(tar_loader))):
                t = time.time()

                lr = self.sib_lr  # self.optimizer.param_groups[0]['lr']

                if isinstance(src_batch, list):
                    src = src_batch[0]  # data, dataset_idx
                    # print('dataset_idx', src_batch[1])
                else:
                    src = src_batch

                src = to_device(src, self.device)
                src_imgs = src['images']
                src_cls_lbls = src['class_labels']
                src_aux_lbls = src['aux_labels']

                self.optimizer.zero_grad()
                src_aux_loss, src_task_loss, _, _ = self.model(src_imgs, lr, src_aux_lbls, src_cls_lbls, sib_update=self.sib_update)
                src_loss = src_task_loss + src_aux_loss * self.src_aux_weight
                loss = src_loss

                if self.task_type == 'DA':
                    tar = to_device(tar_batch, self.device)
                    tar_imgs = tar['images']
                    tar_aux_lbls = tar['aux_labels']
                    # class label is not available for target domain
                    tar_aux_loss, tar_task_loss, _, _ = self.model(tar_imgs, lr, tar_aux_lbls, sib_update=self.sib_update)
                    tar_loss = tar_task_loss * self.tar_entropy_weight + tar_aux_loss * self.tar_aux_weight
                    loss += tar_loss

                loss.backward()
                self.optimizer.step()

                losses.update(loss.item(), src_imgs.size(0))

                # measure elapsed time
                batch_time.update(time.time() - t)

                i_iter += 1
                if i_iter % print_freq == 0:
                    if self.task_type == 'DA':
                        print_string = 'Epoch {:>2} | iter {:>4} | loss: {:2.3f} | src_class: {:2.3f} | src_aux: {:2.3f} | tar_aux: {:2.3f} | {:.1f} s/it'
                        self.logger.info(print_string.format(epoch, i_iter, loss.item(), src_task_loss.item(),
                                                             src_aux_loss.item(), tar_aux_loss.item(), batch_time.avg))
                    else:
                        print_string = 'Epoch {:>2} | iter {:>4} | loss: {:2.3f} | src_class: {:2.3f} | src_aux: {:2.3f} | {:.1f} s/it'
                        self.logger.info(print_string.format(epoch, i_iter, loss.item(), src_task_loss.item(),
                                                             src_aux_loss.item(), batch_time.avg))

            # adjust learning rate
            self.scheduler.step()

            del src_task_loss, src_aux_loss
            if self.task_type == 'DA':
                del tar_task_loss, tar_aux_loss

            # validation
            # if val_loader is not None and len(val_loader) > 0:
            #     self.logger.info('validating...')
            #     class_acc = self.test(val_loader)
            #     self.writer.add_scalar('val/class_acc', class_acc, i_iter)

            if test_loader is not None:
                self.logger.info('testing...')
                class_acc = self.test(test_loader)
                self.writer.add_scalar('test/class_acc', class_acc, i_iter)
                if class_acc > best_acc:
                    best_acc = class_acc
                    self.save(self.args.model_dir, i_iter, is_best=True)
                if not self.sib_update:
                    best_acc_nonsib = best_acc
                self.logger.info('Best testing accuracy: {:.2f} % (non-sib), {:.2f} % (sib)'.format(best_acc_nonsib, best_acc))

        self.logger.info('Best testing accuracy: {:.2f} % (non-sib), {:.2f} % (sib)'.format(best_acc_nonsib, best_acc))
        self.logger.info('Finished Training.')

    def save(self, path, i_iter, is_best=False):
        state = {"iter": i_iter + 1,
                 "model_state": self.model.state_dict(),
                 "optimizer_state": self.optimizer.state_dict(),
                 "scheduler_state": self.scheduler.state_dict(),
                 }
        if is_best:
            save_path = os.path.join(path, 'model_best.pth')
        else:
            save_path = os.path.join(path, 'model_{:06d}.pth'.format(i_iter))
        self.logger.info('Saving model to %s' % save_path)
        torch.save(state, save_path)

    def load(self, path):
        checkpoint = torch.load(path)
        self.model.load_state_dict(checkpoint['model_state'])
        self.logger.info('Loaded model from: ' + path)

        if self.args.mode == 'train':
            self.model.load_state_dict(checkpoint['model_state'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state'])
            self.scheduler.load_state_dict(checkpoint['scheduler_state'])
            self.start_iter = checkpoint['iter']
            self.logger.info('Start iter: %d ' % self.start_iter)

    def test(self, val_loader):
        val_loader_iterator = iter(val_loader)
        num_val_iters = len(val_loader)
        tt = tqdm(range(num_val_iters), total=num_val_iters, desc="Validating")
        class_correct = 0
        aux_correct = 0
        total = 0

        # It is important to set feat_net in eval mode:
        # the performance will be very low if feat_net is in train mode,
        # mainly due to the changed BN layers.
        model = deepcopy(self.model)
        model.feat_net.eval()
        for cur_it in tt:
            data = next(val_loader_iterator)
            if isinstance(data, list):
                data = data[0]
            # get the inputs
            data = to_device(data, self.device)
            imgs = data['images']
            cls_lbls = data['class_labels']
            aux_lbls = data['aux_labels']
            lr = self.sib_lr
            if imgs.size(0) > 1:
                sib_update = self.sib_update
            else:
                sib_update = False
            _, _, aux_logits, cls_logits = model(imgs, lr, aux_lbls, sib_update=sib_update)

            _, cls_pred = cls_logits.max(dim=1)
            _, aux_pred = aux_logits.max(dim=1)

            class_correct += torch.sum(cls_pred == cls_lbls.data)
            aux_correct += torch.sum(aux_pred == aux_lbls.data)
            total += imgs.size(0)

        tt.close()
        del model

        aux_acc = 100 * float(aux_correct) / total
        class_acc = 100 * float(class_correct) / total
        self.logger.info('{} aux_acc: {:.2f} %, class_acc: {:.2f} %'.format(self.args.exp_name, aux_acc, class_acc))
        return class_acc
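
For reference, a minimal driver sketch, assuming the repository's custom modules (schedulers, optimizers, networks, utils) are importable, that the config object is an attribute-style dict (e.g. addict.Dict, which also provides the .get() used above), and that the loaders yield dicts of tensors keyed by 'images', 'class_labels' and 'aux_labels'. All config keys below are inferred from the attribute accesses in the code, and all values are illustrative placeholders, not taken from the original gist:

import logging
from addict import Dict  # assumed dependency: attribute-style dict with .get()

config = Dict({
    'mode': 'train',
    'task_type': 'DA',
    'exp_name': 'sib_da_example',
    'log_dir': './logs',
    'model_dir': './checkpoints',
    'n_classes': 31,                   # example label-space size
    'aux_classes': 3,                  # rotation task: aux_classes + 1 = 4 angles
    'network': {'arch': 'resnet50'},   # placeholder for whatever get_network() expects
    'training': {
        # names/params are placeholders for whatever get_optimizer / get_scheduler expect
        'optimizer': {'name': 'sgd', 'lr': 1e-3, 'momentum': 0.9, 'weight_decay': 5e-4},
        'lr_scheduler': {'name': 'step', 'step_size': 30, 'gamma': 0.1},
        'num_epochs': 60, 'num_print_epoch': 10, 'resume': '',
        'n_steps': 3, 'warmup_epochs': 5, 'sib_lr': 1e-3,
        'with_self_sup_update': True, 'with_synthetic_grad_update': True,
        'src_aux_weight': 1.0, 'tar_aux_weight': 1.0, 'tar_entropy_weight': 0.1,
    },
})

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('sib')

sib = SIBModel(config, logger)
# src_loader, tar_loader and test_loader are assumed to yield dicts with
# 'images', 'class_labels' and 'aux_labels'; target class labels are never
# read inside train():
# sib.train(src_loader, tar_loader, None, test_loader)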