import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms
import os
import numpy as np
import argparse
import csv
import shutil
import sys
sys.path.append("..")
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
from models import *
from data_set.dataset_train import get_dataset
from utils import progress_bar
from utils import stats
from user_define import config as cf
from user_define import hyperparameter as hp
import math
from pytorchtools import EarlyStopping
from models.senet.senet import se_resnet50
from models.inceptionv3 import inception_v3
#net = senet154()
# Basic Parameters Init
BEST_AUC = 0
THRESHOLD = 0.5
START_EPOCH = 0
""" LR_DECAY = 0
LR_CHANCE = 0 """
CUR_EPOCH = []
CUR_LOSS = []
CUR_VAL_ACC = []
CUR_TRA_ACC = []
CUR_LR = []
CUR_TRA_LOSS = []
USE_CUDA = torch.cuda.is_available()
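# matplotlib is imported with the non-interactive 'agg' backend above, but no
# plot is ever produced in this gist; a minimal sketch of how the tracked
# CUR_* lists could be dumped to disk after training (the file name, and
# CUR_LOSS being the validation loss, are assumptions):
#   plt.plot(CUR_EPOCH, CUR_TRA_LOSS, label='train loss')
#   plt.plot(CUR_EPOCH, CUR_LOSS, label='val loss')
#   plt.xlabel('epoch'); plt.legend()
#   plt.savefig('loss_curve.png')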
# Parser Init
parser = argparse.ArgumentParser(description='Camelyon17 Training')
parser.add_argument('--lr', default=hp.default_lr, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
args = parser.parse_args()
# Data loading
print('==> Preparing data..')
trans_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.RandomGrayscale(p=0.1),
    transforms.ToTensor(),
])
trans_test = transforms.Compose([
    transforms.ToTensor(),
])
if hp.mining:
    # bug fix: this get_dataset call was commented out, leaving `trainset`,
    # `valset` and `miningset` undefined on the mining path
    trainset, valset, subtestset, testset, miningset = get_dataset(trans_train, trans_test, hp.train_num, hp.val_num, hp.subtest_num, hp.train_ratio, hp.mining)
    miningloader = torch.utils.data.DataLoader(miningset, batch_size=hp.batch_size,
                                               shuffle=True, num_workers=hp.num_workers)
else:
    trainset, valset = get_dataset(trans_train, trans_test, hp.train_num, hp.val_num, hp.subtest_num)
    #subtestset = get_dataset(trans_train, trans_test, hp.train_num, hp.val_num, hp.subtest_num)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=hp.batch_size,
                                          shuffle=True, num_workers=hp.num_workers, drop_last=True)
valloader = torch.utils.data.DataLoader(valset, batch_size=5,
                                        shuffle=False, num_workers=hp.num_workers, drop_last=True)
#subtestloader = torch.utils.data.DataLoader(subtestset, batch_size=hp.batch_size,
#                                            shuffle=False, num_workers=hp.num_workers, drop_last=True)
#testloader = torch.utils.data.DataLoader(testset, batch_size=hp.batch_size,
#                                         shuffle=False, num_workers=hp.num_workers)
print('Data loading END')
# Model
if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/chance/ckpt.t7')
    net = checkpoint['net']
    BEST_AUC = checkpoint['auc']
    START_EPOCH = checkpoint['epoch']
    THRESHOLD = checkpoint['threshold']
    # clamp the restored learning rate so it never drops below 1e-5
    args.lr = max(checkpoint['lr'], 1e-5)
else:
    print('==> Building model..')
    #net = resnet18()
    #net = resnet34()
    #net = resnet50()
    #net = resnet101()
    #net = resnet152()
    #net = densenet121()
    #net = densenet161()
    #net = densenet201()
    #net = se_resnet101()
    #net = senet154()
    net = inception_v3()
    num_ftrs = net.fc.in_features
    print(num_ftrs)
    net.aux_logits = False
    # Replace the classifier head with a single logit (assumption: the gist
    # never re-wires net.fc, but the float targets and BCEWithLogitsLoss
    # below require a one-unit output)
    net.fc = nn.Linear(num_ftrs, 1)
    #net = se_resnet50()
if USE_CUDA:
    if not args.resume:
        net.cuda()
        net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
    cudnn.benchmark = True
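# The resume branch above reads 'net', 'auc', 'epoch', 'threshold' and 'lr'
# from the checkpoint dict. The matching save call is not part of this gist;
# a sketch of it, with the keys taken from the load side and everything else
# assumed, would look like:
#   state = {'net': net, 'auc': BEST_AUC, 'epoch': epoch,
#            'threshold': THRESHOLD, 'lr': optimizer.param_groups[0]['lr']}
#   torch.save(state, './checkpoint/chance/ckpt.t7')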
# Optimization, Loss Function Init
#criterion = nn.BCELoss()
criterion = nn.BCEWithLogitsLoss()
#criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=hp.default_lr, momentum=hp.momentum, weight_decay=hp.weight_decay)
#optimizer = optim.Adam(net.parameters(), lr=hp.default_lr, weight_decay=hp.weight_decay)
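# EarlyStopping is imported from pytorchtools above but never wired in; a
# minimal sketch of its usual usage (the patience value and the validation
# loop it would live in are assumptions, not part of this gist):
#   early_stopping = EarlyStopping(patience=10, verbose=True)
#   # ...after each validation pass:
#   early_stopping(val_loss, net)
#   if early_stopping.early_stop:
#       print('Early stopping')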
def train(epoch, wrong_save=False):
    '''Train net using patches of a slide.
    Optionally save a csv file listing the patch file names that were
    predicted incorrectly.
    Args:
        epoch (int): current epoch
        wrong_save (bool): if True, save the csv file listing the patch
            file names predicted incorrectly
    '''
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    wrong_list = []
    # to track the per-batch training loss as the model trains
    train_losses = []
    for batch_idx, (inputs, targets, filename) in enumerate(trainloader):
        if USE_CUDA:
            inputs = inputs.cuda()
            targets = torch.FloatTensor(np.array(targets).astype(float)).cuda()
        optimizer.zero_grad()
        inputs, targets = Variable(inputs), Variable(targets)
        # squeeze the single-logit output to match the 1-D float targets
        outputs = net(inputs).squeeze(1)
        batch_size = targets.shape[0]
        # binarize the logits for the accuracy statistics (bug fix: `preds`
        # was referenced below but its computation had been commented out)
        preds = (torch.sigmoid(outputs) > THRESHOLD).float()
        loss = criterion(outputs, targets)
        #loss2 = criterion(aux_outputs, targets)
        #loss = loss1 + 0.4*loss2
        loss.backward()
        optimizer.step()
        # statistics
        train_losses.append(loss.item())  # bug fix: appended an uninitialized running total before
        correct += (preds == targets).sum().item()
        total += targets.size(0)
        if wrong_save:
            # bug fix: the loop variable is `filename`, not `filename_list`,
            # and raw logits were being compared against the 0/1 targets
            for idx in range(len(filename)):
                if preds[idx] != targets[idx]:
                    wrong_list.append(filename[idx])
        train_loss = np.average(train_losses)
        acc = correct / total
        print(acc)
        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss, 100. * correct / total, correct, total))
    if wrong_save:
        wrong_csv = open(cf.wrong_path + 'wrong_data_epoch' + str(epoch) + '.csv', 'w', encoding='utf-8')
        wr = csv.writer(wrong_csv)
        for name in wrong_list:
            wr.writerow([name])
        wrong_csv.close()
    print("Train Loss: ", train_loss)
    CUR_TRA_ACC.append(100. * correct / total)
    CUR_TRA_LOSS.append(train_loss)
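# The gist ends with the train() definition and never calls it; a minimal
# driver-loop sketch (the epoch-count attribute on `hp` is an assumption,
# hence the guarded getattr with a fallback):
if __name__ == '__main__':
    num_epochs = getattr(hp, 'epochs', 50)  # assumed hyperparameter name
    for epoch in range(START_EPOCH, START_EPOCH + num_epochs):
        train(epoch, wrong_save=False)
        CUR_EPOCH.append(epoch)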