# Gist luistelmocosta/61d0ab920bcc2d523fefb099c6395f4e, created February 1, 2019 04:25.
# Note (GitHub web-page banner, not part of the script): this file contains
# bidirectional Unicode text that may be interpreted or compiled differently
# than what appears below. To review, open the file in an editor that reveals
# hidden Unicode characters.
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
import torch.backends.cudnn as cudnn | |
from torch.autograd import Variable | |
import torchvision | |
import torchvision.transforms as transforms | |
import os | |
import numpy as np | |
import argparse | |
import csv | |
import shutil | |
import sys | |
sys.path.append("..") | |
import matplotlib | |
matplotlib.use('agg') | |
import matplotlib.pyplot as plt | |
from models import * | |
from data_set.dataset_train import get_dataset | |
from utils import progress_bar | |
from utils import stats | |
from user_define import config as cf | |
from user_define import hyperparameter as hp | |
import math | |
from pytorchtools import EarlyStopping | |
from models.senet.senet import se_resnet50 | |
from models.inceptionv3 import inception_v3 | |
#net = senet154() | |
# Basic Parameters Init
# Module-level training state. BEST_AUC, START_EPOCH and THRESHOLD are
# overwritten from the checkpoint when the script is started with --resume.
BEST_AUC = 0
THRESHOLD = 0.5
START_EPOCH = 0
# Disabled settings, kept for reference:
# LR_DECAY = 0
# LR_CHANCE = 0

# Per-epoch history lists. train() appends to CUR_TRA_ACC and CUR_TRA_LOSS;
# the others are presumably filled by the validation/plotting code - verify.
CUR_EPOCH = []
CUR_LOSS = []
CUR_VAL_ACC = []
CUR_TRA_ACC = []
CUR_LR = []
CUR_TRA_LOSS = []

# True when at least one CUDA device is usable; gates every .cuda() call.
USE_CUDA = torch.cuda.is_available()
# Parser Init
# Command-line interface:
#   --lr           learning rate (float), defaults to hp.default_lr
#   --resume / -r  flag; when given, training resumes from the saved checkpoint
parser = argparse.ArgumentParser(description='Camelyon17 Training')
parser.add_argument('--lr', default=hp.default_lr, type=float,
                    help='learning rate')
parser.add_argument('--resume', '-r', action='store_true',
                    help='resume from checkpoint')
args = parser.parse_args()
# Data loading
print('==> Preparing data..')

# Training patches get random flips, colour jitter and occasional grayscale
# conversion before being turned into tensors.
trans_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.RandomGrayscale(p=0.1),
    transforms.ToTensor(),
])

# Evaluation patches are only converted to tensors (no augmentation).
trans_test = transforms.Compose([
    transforms.ToTensor(),
])

if hp.mining:
    # The call that used to produce `miningset` is disabled:
    #   trainset, valset, subtestset, testset, miningset = get_dataset(
    #       trans_train, trans_test, hp.train_num, hp.val_num,
    #       hp.subtest_num, hp.train_ratio, hp.mining)
    # Building the mining loader would therefore die with a NameError and
    # leave trainset/valset undefined, so stop with an explicit message.
    raise NotImplementedError(
        'hp.mining is enabled but the mining dataset is not built: '
        'restore the get_dataset(..., hp.train_ratio, hp.mining) call first')

trainset, valset = get_dataset(trans_train, trans_test,
                               hp.train_num, hp.val_num, hp.subtest_num)

# drop_last=True discards the final incomplete batch of each loader.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=hp.batch_size,
                                          shuffle=True, num_workers=hp.num_workers,
                                          drop_last=True)
# NOTE(review): the validation batch size is hard-coded to 5 - confirm intended.
valloader = torch.utils.data.DataLoader(valset, batch_size=5,
                                        shuffle=False, num_workers=hp.num_workers,
                                        drop_last=True)
# The subtest and test loaders are currently disabled.
print('Data loading END')
# Model
if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    ckpt_path = './checkpoint/chance/ckpt.t7'
    # Check the file that is actually loaded (not just the top-level
    # directory) and raise explicitly: an assert disappears under `python -O`.
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError('Error: no checkpoint found at ' + ckpt_path)
    # NOTE: the checkpoint stores the whole pickled model, and torch.load
    # unpickles arbitrary objects - only resume from trusted files.
    # map_location lets a checkpoint saved on GPU be opened on a CPU-only host.
    checkpoint = torch.load(ckpt_path,
                            map_location=None if USE_CUDA else 'cpu')
    net = checkpoint['net']
    BEST_AUC = checkpoint['auc']
    START_EPOCH = checkpoint['epoch']
    THRESHOLD = checkpoint['threshold']
    # Resume with the stored learning rate, but never below 1e-5.
    args.lr = max(checkpoint['lr'], 1e-5)
else:
    print('==> Building model..')
    # Other backbones that were tried here: resnet18/34/50/101/152,
    # densenet121/161/201, se_resnet50, se_resnet101, senet154.
    net = inception_v3()
    num_ftrs = net.fc.in_features
    print(num_ftrs)
    # Presumably makes forward() return only the main logits (no auxiliary
    # head output) - verify against models.inceptionv3.
    net.aux_logits = False

if USE_CUDA:
    if not args.resume:
        # Only a freshly built model is moved to the GPU and wrapped; a
        # resumed `net` is used exactly as it was stored in the checkpoint.
        net.cuda()
        net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
    cudnn.benchmark = True
# Optimization, Loss Function Init
# BCEWithLogitsLoss applies the sigmoid internally, so the network must emit
# raw logits. (Alternatives tried: nn.BCELoss, nn.CrossEntropyLoss.)
criterion = nn.BCEWithLogitsLoss()
# Use args.lr rather than hp.default_lr: it defaults to hp.default_lr, but it
# also carries the --lr command-line value and the learning rate restored
# from the checkpoint on --resume, which were previously ignored.
# (Alternative tried: optim.Adam with the same weight decay.)
optimizer = optim.SGD(net.parameters(), lr=args.lr,
                      momentum=hp.momentum, weight_decay=hp.weight_decay)
def train(epoch, wrong_save=False):
    '''Train the net for one epoch using patches of slides.

    Runs one pass over the module-level ``trainloader``, updating ``net``
    with ``optimizer`` / ``criterion``, and appends the epoch accuracy (in
    percent) and the mean per-sample loss to ``CUR_TRA_ACC`` and
    ``CUR_TRA_LOSS``.

    Args:
        epoch (int): current epoch
        wrong_save (bool): If True, save a csv file (one row per patch) with
            the file names of the patches predicted incorrectly this epoch
    '''
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_sum = 0.0      # sum of per-sample losses seen so far
    train_loss = 0.0    # running mean loss, reported after every batch
    correct = 0
    total = 0
    wrong_list = []

    for batch_idx, (inputs, targets, filenames) in enumerate(trainloader):
        # BCEWithLogitsLoss needs float targets on every device, not only
        # when CUDA is available.
        targets = torch.as_tensor(np.asarray(targets), dtype=torch.float32)
        if USE_CUDA:
            inputs = inputs.cuda()
            targets = targets.cuda()

        optimizer.zero_grad()
        # Assumes the net emits one logit per sample; flatten it so it lines
        # up element-wise with `targets` - TODO confirm against the model.
        outputs = net(inputs).reshape(-1)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # statistics
        batch_size = targets.size(0)
        loss_sum += loss.item() * batch_size
        total += batch_size
        # The outputs are raw logits, so squash them to probabilities before
        # applying the decision threshold.
        preds = (torch.sigmoid(outputs.detach()) > THRESHOLD).float()
        matches = (preds == targets).tolist()
        correct += sum(1 for ok in matches if ok)

        if wrong_save:
            wrong_list.extend(
                name for name, ok in zip(filenames, matches) if not ok)

        train_loss = loss_sum / total
        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss, 100.*correct/total, correct, total))

    if wrong_save:
        # newline='' is what the csv module expects, and the context manager
        # closes the file even if a write fails.
        with open(cf.wrong_path+'wrong_data_epoch'+str(epoch)+'.csv', 'w',
                  encoding='utf-8', newline='') as wrong_csv:
            csv.writer(wrong_csv).writerows([name] for name in wrong_list)

    print("Train Loss: ", train_loss)
    # Guard against an empty loader so the history update cannot divide by 0.
    CUR_TRA_ACC.append(100.*correct/total if total else 0.0)
    CUR_TRA_LOSS.append(train_loss)
# (GitHub web-page footer, not part of the script: "Sign up for free to join
# this conversation on GitHub. Already have an account? Sign in to comment.")