@hushell
Created March 25, 2021 09:03
zero-shot SIB for domain adaptation
import os
import time
import math
import itertools
import numpy as np
from PIL import Image
from tqdm import tqdm
from copy import deepcopy
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
# custom modules
from schedulers import get_scheduler
from optimizers import get_optimizer
from networks import get_network
from utils.metrics import AverageMeter
from utils.utils import to_device
# summary
from tensorboardX import SummaryWriter


def entropy_loss(x):
    return torch.sum(-F.softmax(x, 1) * F.log_softmax(x, 1), 1).mean()


class DNINet(nn.Module):
    def __init__(self, input_dims, dni_hidden_size=1024):
        super(DNINet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Linear(input_dims, dni_hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(dni_hidden_size)
        )
        self.layer2 = nn.Sequential(
            nn.Linear(dni_hidden_size, dni_hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(dni_hidden_size)
        )
        self.layer3 = nn.Linear(dni_hidden_size, input_dims)

        # self.apply(weight_init)
        for m in self.modules():
            if isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                nn.init.normal_(m.bias)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        return out
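

# Note: DNINet serves as a synthetic-gradient module in the spirit of decoupled
# neural interfaces / SIB: given the current classification logits (n x nClass),
# it predicts a surrogate gradient of the (unavailable) task loss w.r.t. those
# logits. SIBNet.synthetic_grad_update below turns that prediction into an update
# of the classifier weights theta via a vector-Jacobian product
# (torch.autograd.grad with grad_outputs=...), so theta can be adapted without labels.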
class SIBNet(nn.Module):
    def __init__(self, args):
        super(SIBNet, self).__init__()
        self.args = args
        self.n_steps = args.training.n_steps

        # meta-model: feature-net & synthetic-grad-net
        self.feat_net = get_network(args.network.arch)(output='feature')
        self.feat_dim = self.feat_net.feat_dim
        self.dni = DNINet(args.n_classes)

        # meta-model: classification
        self.theta = nn.Parameter(torch.FloatTensor(args.n_classes, self.feat_dim))
        nn.init.kaiming_uniform_(self.theta, a=math.sqrt(5))
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.theta)
        bound = 1 / math.sqrt(fan_in)
        self.bias_cls = nn.Parameter(torch.FloatTensor(args.n_classes))
        nn.init.uniform_(self.bias_cls, -bound, bound)
        self.up_dim = self.feat_dim
        self.up = nn.Parameter(torch.FloatTensor(self.up_dim, args.n_classes))
        nn.init.kaiming_normal_(self.up)

        # meta-model: auxiliary self-sup (e.g. rotation)
        self.rho = nn.Parameter(torch.FloatTensor(args.aux_classes + 1, self.feat_dim + self.up_dim))
        nn.init.kaiming_uniform_(self.rho, a=math.sqrt(5))
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.rho)
        bound = 1 / math.sqrt(fan_in)
        self.bias_aux = nn.Parameter(torch.FloatTensor(args.aux_classes + 1))
        nn.init.uniform_(self.bias_aux, -bound, bound)

        # loss
        self.class_loss_func = nn.CrossEntropyLoss()

    def multitask_predict(self, features, cls_weights, aux_weights):
        '''
        (n x nFeat, nClass x nFeat, nRot x 2*nFeat) -> (n x nClass, n x nRot)
        '''
        tensor_cls = F.linear(features, cls_weights, self.bias_cls)
        tensor_up = torch.einsum('nk,dk->nd', tensor_cls, self.up)
        tensor_cat = torch.cat([features, tensor_up], dim=1)  # n x (2*nFeat)
        tensor_aux = F.linear(tensor_cat, aux_weights, self.bias_aux)
        return tensor_cls, tensor_aux

    def self_sup_update(self, aux_labels, theta, features, lr, n_steps=1):
        theta = theta.clone()  # TODO: check this doesn't affect theta in meta-model
        for t in range(n_steps):
            cls_scores, aux_scores = self.multitask_predict(features, theta, self.rho.detach())
            aux_loss = self.class_loss_func(aux_scores, aux_labels)
            grad = torch.autograd.grad([aux_loss], [theta],
                                       create_graph=True, retain_graph=True,
                                       only_inputs=True)[0]  # nClass x nFeat
            # perform GD
            theta = theta - lr * grad
        return theta

    def synthetic_grad_update(self, theta, features, lr, n_steps=3):
        theta = theta.clone()
        for t in range(n_steps):
            cls_scores, aux_scores = self.multitask_predict(features, theta, self.rho.detach())
            grad_logit = self.dni(cls_scores)  # n x nClass
            grad = torch.autograd.grad([cls_scores], [theta],
                                       grad_outputs=[grad_logit],
                                       create_graph=True, retain_graph=True,
                                       only_inputs=True)[0]  # nClass x nFeat
            # perform GD
            theta = theta - lr * grad
        return theta

    def forward(self, img, lr, aux_labels, cls_labels=None, sib_update=True):
        feat = self.feat_net(img)

        # generate classification weights for this minibatch
        theta = self.theta
        if self.args.training.with_self_sup_update and sib_update:
            theta = self.self_sup_update(aux_labels, theta,
                                         feat.detach(), lr, n_steps=1)
        if self.args.training.with_synthetic_grad_update and sib_update:
            theta = self.synthetic_grad_update(theta, feat.detach(), lr, n_steps=self.n_steps)

        cls_scores, aux_scores = self.multitask_predict(feat, theta, self.rho)
        aux_loss = self.class_loss_func(aux_scores, aux_labels)
        if cls_labels is None:
            task_loss = entropy_loss(cls_scores[aux_labels == 0])
        else:
            task_loss = self.class_loss_func(cls_scores, cls_labels)
        return aux_loss, task_loss, aux_scores, cls_scores
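

# Overview of SIBNet.forward: starting from the shared classifier theta, it optionally
# takes (i) one gradient step on the auxiliary self-supervised loss (e.g. rotation
# prediction) computed on detached features, and (ii) n_steps synthetic-gradient steps
# driven by DNINet; the adapted theta then scores the minibatch. When cls_labels is
# None (unlabelled target domain), the task loss is the entropy of the class scores
# restricted to samples with aux_labels == 0 (presumably the un-rotated images).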
class SIBModel:
    def __init__(self, args, logger):
        self.args = args
        self.logger = logger
        self.writer = SummaryWriter(args.log_dir)
        cudnn.enabled = True

        self.task_type = args.get('task_type', 'DA')
        print('Task type: {}'.format(self.task_type))

        # set up model
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SIBNet(args)
        self.model = self.model.to(self.device)

        # optimizer & scheduler & loss
        if args.mode == 'train':
            # set up optimizer, lr scheduler and loss functions
            optimizer = get_optimizer(self.args.training.optimizer)
            optimizer_params = {k: v for k, v in self.args.training.optimizer.items() if k != "name"}
            self.optimizer = optimizer(self.model.parameters(), **optimizer_params)
            self.scheduler = get_scheduler(self.optimizer, self.args.training.lr_scheduler)
            self.start_iter = 0
            # resume
            if args.training.resume:
                self.load(args.model_dir + '/' + args.training.resume)
        elif args.mode == 'val':
            self.load(os.path.join(args.model_dir, args.validation.model))
        else:
            self.load(os.path.join(args.model_dir, args.testing.model))

        self.src_aux_weight = args.training.src_aux_weight
        self.tar_aux_weight = args.training.tar_aux_weight
        self.tar_entropy_weight = args.training.tar_entropy_weight

        # algorithmic params
        self.n_steps = args.training.n_steps
        self.warmup_epochs = self.args.training.warmup_epochs
        self.sib_lr = self.args.training.sib_lr
        self.sib_update = False
        self.logger.info('warmup epochs: {}'.format(self.warmup_epochs))

    def train(self, src_loader, tar_loader, val_loader, test_loader):
        num_batches = len(src_loader)
        print('Number of batches: %d' % num_batches)
        print_freq = max(num_batches // self.args.training.num_print_epoch, 1)

        i_iter = self.start_iter
        start_epoch = i_iter // num_batches
        num_epochs = self.args.training.num_epochs
        best_acc = 0
        best_acc_nonsib = 0
        for epoch in range(start_epoch, num_epochs):
            self.model.train()
            batch_time = AverageMeter()
            losses = AverageMeter()

            if epoch > self.warmup_epochs:
                self.sib_update = True
                self.logger.info('Epoch {:>2} : sib_update enabled'.format(epoch))

            for it, (src_batch, tar_batch) in enumerate(zip(src_loader, itertools.cycle(tar_loader))):
                t = time.time()
                lr = self.sib_lr  # self.optimizer.param_groups[0]['lr']

                if isinstance(src_batch, list):
                    src = src_batch[0]  # data, dataset_idx
                    # print('dataset_idx', src_batch[1])
                else:
                    src = src_batch
                src = to_device(src, self.device)
                src_imgs = src['images']
                src_cls_lbls = src['class_labels']
                src_aux_lbls = src['aux_labels']

                self.optimizer.zero_grad()

                src_aux_loss, src_task_loss, _, _ = self.model(src_imgs, lr, src_aux_lbls, src_cls_lbls,
                                                               sib_update=self.sib_update)
                src_loss = src_task_loss + src_aux_loss * self.src_aux_weight
                loss = src_loss

                if self.task_type == 'DA':
                    tar = to_device(tar_batch, self.device)
                    tar_imgs = tar['images']
                    tar_aux_lbls = tar['aux_labels']
                    # class label is not available for target domain
                    tar_aux_loss, tar_task_loss, _, _ = self.model(tar_imgs, lr, tar_aux_lbls,
                                                                   sib_update=self.sib_update)
                    tar_loss = tar_task_loss * self.tar_entropy_weight + tar_aux_loss * self.tar_aux_weight
                    loss += tar_loss

                loss.backward()
                self.optimizer.step()

                losses.update(loss.item(), src_imgs.size(0))
                # measure elapsed time
                batch_time.update(time.time() - t)

                i_iter += 1
                if i_iter % print_freq == 0:
                    if self.task_type == 'DA':
                        print_string = 'Epoch {:>2} | iter {:>4} | loss: {:2.3f} | src_class: {:2.3f} | src_aux: {:2.3f} | tar_aux: {:2.3f} | {:.1f} s/it'
                        self.logger.info(print_string.format(epoch, i_iter, loss.item(), src_task_loss.item(),
                                                             src_aux_loss.item(), tar_aux_loss.item(), batch_time.avg))
                    else:
                        print_string = 'Epoch {:>2} | iter {:>4} | loss: {:2.3f} | src_class: {:2.3f} | src_aux: {:2.3f} | {:.1f} s/it'
                        self.logger.info(print_string.format(epoch, i_iter, loss.item(), src_task_loss.item(),
                                                             src_aux_loss.item(), batch_time.avg))

            # adjust learning rate
            self.scheduler.step()

            del src_task_loss, src_aux_loss
            if self.task_type == 'DA':
                del tar_task_loss, tar_aux_loss

            # validation
            # if val_loader is not None and len(val_loader) > 0:
            #     self.logger.info('validating...')
            #     class_acc = self.test(val_loader)
            #     self.writer.add_scalar('val/class_acc', class_acc, i_iter)

            if test_loader is not None:
                self.logger.info('testing...')
                class_acc = self.test(test_loader)
                self.writer.add_scalar('test/class_acc', class_acc, i_iter)
                if class_acc > best_acc:
                    best_acc = class_acc
                    self.save(self.args.model_dir, i_iter, is_best=True)
                if not self.sib_update:
                    best_acc_nonsib = best_acc
                self.logger.info('Best testing accuracy: {:.2f} % (non-sib), {:.2f} % (sib)'.format(best_acc_nonsib, best_acc))

        self.logger.info('Best testing accuracy: {:.2f} % (non-sib), {:.2f} % (sib)'.format(best_acc_nonsib, best_acc))
        self.logger.info('Finished Training.')

    def save(self, path, i_iter, is_best=False):
        state = {"iter": i_iter + 1,
                 "model_state": self.model.state_dict(),
                 "optimizer_state": self.optimizer.state_dict(),
                 "scheduler_state": self.scheduler.state_dict(),
                 }
        if is_best:
            save_path = os.path.join(path, 'model_best.pth')
        else:
            save_path = os.path.join(path, 'model_{:06d}.pth'.format(i_iter))
        self.logger.info('Saving model to %s' % save_path)
        torch.save(state, save_path)

    def load(self, path):
        checkpoint = torch.load(path)
        self.model.load_state_dict(checkpoint['model_state'])
        self.logger.info('Loaded model from: ' + path)
        if self.args.mode == 'train':
            self.optimizer.load_state_dict(checkpoint['optimizer_state'])
            self.scheduler.load_state_dict(checkpoint['scheduler_state'])
            self.start_iter = checkpoint['iter']
            self.logger.info('Start iter: %d ' % self.start_iter)

    def test(self, val_loader):
        val_loader_iterator = iter(val_loader)
        num_val_iters = len(val_loader)
        tt = tqdm(range(num_val_iters), total=num_val_iters, desc="Validating")
        class_correct = 0
        aux_correct = 0
        total = 0

        # It is important to set feat_net in eval mode:
        # the performance will be very low if feat_net is in train mode,
        # mainly due to the changed BN layers.
        model = deepcopy(self.model)
        model.feat_net.eval()

        for cur_it in tt:
            data = next(val_loader_iterator)
            if isinstance(data, list):
                data = data[0]
            # Get the inputs
            data = to_device(data, self.device)
            imgs = data['images']
            cls_lbls = data['class_labels']
            aux_lbls = data['aux_labels']

            lr = self.sib_lr
            if imgs.size(0) > 1:
                sib_update = self.sib_update
            else:
                sib_update = False
            _, _, aux_logits, cls_logits = model(imgs, lr, aux_lbls, sib_update=sib_update)

            _, cls_pred = cls_logits.max(dim=1)
            _, aux_pred = aux_logits.max(dim=1)
            class_correct += torch.sum(cls_pred == cls_lbls.data)
            aux_correct += torch.sum(aux_pred == aux_lbls.data)
            total += imgs.size(0)

        tt.close()
        del model

        aux_acc = 100 * float(aux_correct) / total
        class_acc = 100 * float(class_correct) / total
        self.logger.info('{} aux_acc: {:.2f} %, class_acc: {:.2f} %'.format(self.args.exp_name, aux_acc, class_acc))
        return class_acc
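

if __name__ == '__main__':
    # Minimal, self-contained sanity check of the synthetic-gradient step used in
    # SIBNet.synthetic_grad_update. This is only a sketch: the batch size, feature
    # dimension, class count and learning rate below are arbitrary and are not taken
    # from any config of this project.
    n, n_feat, n_cls, lr = 8, 16, 4, 1e-3
    dni = DNINet(n_cls)
    dni.eval()  # BatchNorm in eval mode so a single small batch behaves deterministically
    theta = torch.randn(n_cls, n_feat, requires_grad=True)
    feats = torch.randn(n, n_feat)
    logits = F.linear(feats, theta)      # n x nClass
    grad_logit = dni(logits)             # synthetic d(loss)/d(logits), n x nClass
    # vector-Jacobian product: push the synthetic logit gradient back onto theta
    grad_theta = torch.autograd.grad(logits, theta, grad_outputs=grad_logit)[0]
    theta_new = theta - lr * grad_theta  # one gradient-descent step on theta
    print(theta_new.shape)               # expected: torch.Size([4, 16])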