@Franklin-Yao
Last active June 7, 2020 18:44
Lightning and 16-bit precision, range test
Check my git repo styleMix for the full project.
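The description above mentions 16-bit precision and a learning-rate range test, but in the code below the Trainer is built with default 32-bit precision and the lr_find call is commented out. Here is a minimal, hedged sketch of how both would be switched on with the Lightning API used in this gist; `model`, `train_loader`, `val_loader`, and `args` are assumed to be the objects constructed in main() further down, so treat it as an illustration rather than part of the original run.

# Sketch only (assumes model, train_loader, val_loader and args from main() below).
from pytorch_lightning import Trainer

trainer = Trainer(gpus=1, precision=16, max_epochs=args.max_epoch)  # 16-bit (mixed) precision
# learning-rate range test; the same calls appear commented out in main()
lr_finder = trainer.lr_find(model, train_dataloader=train_loader, val_dataloaders=val_loader)
args.lr = lr_finder.suggestion()  # configure_optimizers() reads self.args.lr
trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)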
import torch.nn as nn
from utils import Model_type, euclidean_dist
from torch.nn import functional as F
import torch
from pytorch_lightning.core.lightning import LightningModule
import numpy as np
from collections import OrderedDict
from time import time
from utils import count_acc
# --- conventional supervised training ---
class BaselineTrain(LightningModule):
    def __init__(self, encoder, args, loss_type='softmax'):
        super().__init__()
        self.encoder = encoder
        self.args = args
        self.epoch = 0
        self.old_time = time()
        # output dimension of the backbone
        if args.model_type is Model_type.ResNet12:
            final_feat_dim = 640
        else:
            raise ValueError('unsupported model_type: {}'.format(args.model_type))
        # number of base classes in the training split
        if args.dataset == 'MiniImageNet':
            n_class = 64
        else:
            raise ValueError('unsupported dataset: {}'.format(args.dataset))
        self.classifier = nn.Linear(final_feat_dim, n_class)
        self.classifier.bias.data.fill_(0)
        self.loss_fn = nn.CrossEntropyLoss()
    def forward(self, data, mode='train'):
        args = self.args
        if mode not in ['train']:
            # episodic (few-shot) evaluation: prototypical-network style scoring
            data = data.view(args.n_way * (args.n_shot + args.n_query), *data.size()[2:])
            feature = self.encoder(data)
            feature = feature.view(args.n_way, args.n_shot + args.n_query, -1)
            z_support = feature[:, :args.n_shot]
            z_query = feature[:, args.n_shot:]
            proto = z_support.view(args.n_way, args.n_shot, -1).mean(1)
            z_query = z_query.contiguous().view(args.n_way * args.n_query, -1)
            scores = -euclidean_dist(z_query, proto) / self.args.temperature
        else:
            # standard supervised classification over the base classes
            feature = self.encoder(data)
            scores = self.classifier(feature)
        return scores
    def training_step(self, batch, batch_idx):
        args = self.args
        data, index_label = batch[0].cuda(), batch[1].cuda()
        logits = self(data, 'train')
        label = index_label
        loss = F.cross_entropy(logits, label)
        print_freq = 20
        if (batch_idx + 1) % print_freq == 0:
            print('Epoch {}, {}/{}, loss={:.4f}'.format(self.epoch, batch_idx, 300, loss.item()))
        return {'loss': loss}
    def training_epoch_end(self, outputs):
        avg_loss = np.mean([x['loss'].item() for x in outputs])
        # print('Train loss={:.4f}'.format(avg_loss))
        return {'avg loss': avg_loss}
    def validation_step(self, batch, batch_idx):
        args = self.args
        data, index_label = batch[0].cuda(), batch[1].cuda()
        # episodic labels: 0..n_way-1, each repeated n_query times
        label = torch.from_numpy(np.repeat(range(args.n_way), args.n_query))
        label = label.cuda()
        logits = self(data, mode='val')
        loss = F.cross_entropy(logits, label)
        acc = count_acc(logits, label)
        return {'val_acc': acc, 'val_loss': loss}
    def validation_epoch_end(self, outputs):
        self.epoch = self.epoch + 1
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_acc = np.mean([x['val_acc'] for x in outputs])
        print('Validation loss={:.4f} acc={:.4f}, time={:.3f}'.format(avg_loss.item(), avg_acc, time() - self.old_time))
        self.old_time = time()
        return {'val_acc': avg_acc, 'val_loss': avg_loss}
    def configure_optimizers(self):
        args = self.args
        from torch.optim import SGD, lr_scheduler
        optimizer = SGD(self.parameters(), lr=self.args.lr, momentum=0.9, weight_decay=5e-4)
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epoch, eta_min=0, last_epoch=-1)
        return [optimizer], [scheduler]

# --- training / evaluation script (second file in the gist) ---
import os
import os.path as osp
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from data.datamgr import SimpleDataManager, SetDataManager
from methods.StyleMix import StyleMix
from utils import Averager, Timer, count_acc, compute_confidence_interval, Model_type, Method_type, \
    save_model, load_pretrained_weights, init, resume_model
def train_one_epoch(model, scheduler, optimizer, args, train_loader, label, writer, epoch):
    # for i in range(len(train_loader)):
    #     scheduler.step()
    # return
    model.train()
    print_freq = 10
    for i, batch in enumerate(train_loader):
        data, index_label = batch[0].cuda(), batch[1].cuda()
        if args.method_type is Method_type.style:
            # two classification heads from the StyleMix model
            logits, logits1 = model(data, 'train')
            loss = F.cross_entropy(logits, label) + F.cross_entropy(logits1, label)
            if args.exp_tag in ['same_labels']:
                # symmetric KL divergence between the two heads' predictions
                p, q = F.softmax(logits, dim=1), F.softmax(logits1, dim=1)
                loss1 = F.kl_div(p.log(), q, reduction='batchmean') \
                        + F.kl_div(q.log(), p, reduction='batchmean')
                loss = loss + loss1
        else:
            logits = model(data, 'train')
            if args.method_type is Method_type.baseline:
                # baseline trains with the dataset's class indices, not episodic labels
                label = index_label
            loss = F.cross_entropy(logits, label)
        acc = count_acc(logits, label)
        if i % print_freq == print_freq - 1:
            if args.exp_tag in ['same_labels']:
                print('epoch {}, train {}/{}, loss={:.4f}, KL_loss={:.4f}, acc={:.4f}'.format(
                    epoch, i, len(train_loader), loss.item(), loss1.item(), acc))
            else:
                print('epoch {}, train {}/{}, loss={:.4f} acc={:.4f}'.format(
                    epoch, i, len(train_loader), loss.item(), acc))
            if writer is not None:
                writer.add_scalar('loss', loss)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
def val(model, args, val_loader, label):
    model.eval()
    vl = Averager()
    va = Averager()
    with torch.no_grad():
        for i, batch in tqdm(enumerate(val_loader, 1), total=len(val_loader)):
            data, index_label = batch[0].cuda(), batch[1].cuda()
            logits = model(data, mode='val')
            loss = F.cross_entropy(logits, label)
            acc = count_acc(logits, label)
            vl.add(loss.item())
            va.add(acc)
    vl = vl.item()
    va = va.item()
    return vl, va
def test(model, label, args, few_shot_params):
    if args.debug:
        n_test = 10
        print_freq = 2
    else:
        n_test = 1000
        print_freq = 100
    test_file = args.dataset_dir + 'test.json'
    test_datamgr = SetDataManager(args.exp_tag, test_file, args.dataset_dir, args.image_size,
                                  mode='val', n_episode=n_test, **few_shot_params)
    loader = test_datamgr.get_data_loader(aug=False)
    test_acc_record = np.zeros((n_test,))
    # load the best checkpoint saved during validation
    warmup_state = torch.load(osp.join(args.checkpoint_dir, 'max_acc' + '.pth'))['params']
    model.load_state_dict(warmup_state, strict=False)
    model.eval()
    ave_acc = Averager()
    with torch.no_grad():
        for i, batch in enumerate(loader, 1):
            data, index_label = batch[0].cuda(), batch[1].cuda()
            logits = model(data, 'test')
            acc = count_acc(logits, label)
            ave_acc.add(acc)
            test_acc_record[i - 1] = acc
            if i % print_freq == 0:
                print('batch {}: {:.2f}({:.2f})'.format(i, ave_acc.item() * 100, acc * 100))
    m, pm = compute_confidence_interval(test_acc_record)
    # print('Val Best Epoch {}, Acc {:.4f}, Test Acc {:.4f}'.format(trlog['max_acc_epoch'], trlog['max_acc'],
    #                                                               ave_acc.item()))
    print('Test Acc {:.4f} + {:.4f}'.format(m, pm))
    acc_str = '%4.2f' % (m * 100)
    with open(args.save_dir + '/result.txt', 'a') as f:
        f.write('%s %s\n' % (acc_str, args.name))
def main():
    timer = Timer()
    args, writer = init()
    if args.exp_tag in ['sen']:
        if args.test:
            from sen.main_base import base_test
            return base_test(args)
        else:
            from sen.main_base import base_train
            return base_train(args)
    train_file = args.dataset_dir + 'train.json'
    val_file = args.dataset_dir + 'val.json'
    few_shot_params = dict(n_way=args.n_way, n_support=args.n_shot, n_query=args.n_query)
    n_episode = 10 if args.debug else 100
    if args.method_type is Method_type.baseline:
        train_datamgr = SimpleDataManager(train_file, args.dataset_dir, args.image_size, batch_size=128)
        train_loader = train_datamgr.get_data_loader(aug=True)
    else:
        train_datamgr = SetDataManager(args.exp_tag, train_file, args.dataset_dir, args.image_size,
                                       n_episode=n_episode, mode='train', **few_shot_params)
        train_loader = train_datamgr.get_data_loader(aug=True)
    val_datamgr = SetDataManager(args.exp_tag, val_file, args.dataset_dir, args.image_size,
                                 n_episode=n_episode, mode='val', **few_shot_params)
    val_loader = val_datamgr.get_data_loader(aug=False)
    if args.model_type is Model_type.ConvNet:
        pass  # ConvNet backbone not included in this gist
    elif args.model_type is Model_type.ResNet12:
        # from networks.resnet import resnet12
        # encoder = resnet12(exp_tag=args.exp_tag)
        from networks.sen_backbone import ResNet12
        encoder = ResNet12(args.exp_tag)
    else:
        raise ValueError('')
    if args.method_type is Method_type.baseline:
        # from methods.baselinetrain import BaselineTrain
        # model = BaselineTrain(encoder, args)
        from lightning.baseline import BaselineTrain
        model = BaselineTrain(encoder, args)
    elif args.method_type is Method_type.protonet:
        from methods.protonet import ProtoNet
        model = ProtoNet(encoder, args)
    elif args.method_type is Method_type.style:
        model = StyleMix(encoder, args, dropout=0.5)
    else:
        raise ValueError('')
    model = model.cuda()

    os.environ["KMP_WARNINGS"] = "FALSE"
    import warnings
    warnings.filterwarnings("ignore")
    from pytorch_lightning import Trainer
    trainer = Trainer(gpus=1, default_root_dir=args.checkpoint_dir, max_epochs=args.max_epoch,
                      val_percent_check=1.0, fast_dev_run=False, profiler=False, progress_bar_refresh_rate=0)
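    # Assumption (not in the original code): the gist title mentions 16-bit precision;
    # with this Lightning version that would be requested by adding precision=16 to the
    # Trainer constructor above, e.g. Trainer(gpus=1, precision=16, ...).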
    print('len of dataloader: ' + str(len(train_loader)))
    # lr_finder = trainer.lr_find(model, train_dataloader=train_loader, val_dataloaders=val_loader)
    # args.lr = lr_finder.suggestion()
    # print('learning rate suggested by lightning: ' + str(args.lr))
    # model.hparams.lr = args.lr
    trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)
    return

    # --- manual (non-Lightning) training loop below; unreachable while the return above is in place ---
    from torch.optim import SGD, lr_scheduler
    optimizer = SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=0.1, steps_per_epoch=len(train_loader), epochs=args.max_epoch)
    # scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epoch, eta_min=0, last_epoch=-1)
    # optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    args.ngpu = torch.cuda.device_count()
    torch.backends.cudnn.benchmark = True
    model = model.cuda()
    label = torch.from_numpy(np.repeat(range(args.n_way), args.n_query))
    label = label.cuda()
    if args.test:
        test(model, label, args, few_shot_params)
        return
    if args.resume:
        resume_OK = resume_model(model, optimizer, args)
    else:
        resume_OK = False
    if (not resume_OK) and (args.warmup is not None):
        load_pretrained_weights(model, args)
    max_acc = 0.0
    if args.debug:
        args.max_epoch = args.start_epoch + 1
    for epoch in range(args.start_epoch, args.max_epoch):
        print('learning rate: ' + str(optimizer.param_groups[0]['lr']))
        train_one_epoch(model, scheduler, optimizer, args, train_loader, label, writer, epoch)
        # continue
        vl, va = val(model, args, val_loader, label)
        if writer is not None:
            writer.add_scalar('data/val_acc', float(va), epoch)
        print('epoch {}, val, loss={:.4f} acc={:.4f}'.format(epoch, vl, va))
        if va >= max_acc:
            max_acc = va
            print('saving the best model! acc={:.4f}'.format(va))
            save_model(model, optimizer, args, epoch, 'max_acc')
        save_model(model, optimizer, args, epoch, 'epoch-last')
        if epoch != 0:
            print('ETA:{}/{}'.format(timer.measure(), timer.measure(epoch / args.max_epoch)))
        # return
    if writer is not None:
        writer.close()
    # Test Phase
    test(model, label, args, few_shot_params)


if __name__ == '__main__':
    main()