import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, random_split
from scipy.ndimage import zoom
from skimage import exposure
from livelossplot import PlotLosses
from apex import amp
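
# Third-party dependencies: livelossplot (pip install livelossplot) and
# NVIDIA apex, built from source (https://github.com/NVIDIA/apex).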

learning_rate = 0.0001
num_epochs = 100
num_classes = 3
batch_size = 14
opt_level = 'O1'  # apex mixed-precision optimization level
num_workers = 4

torch.manual_seed(0)

# The .npy file stores an object array (volumes of varying depth plus labels),
# so it must be loaded with allow_pickle=True.
arrays = np.load('/scratch/faraz/Thesis/clinical_data_array_128x128.npy', allow_pickle=True)
images = arrays[:, 1]
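# Assumed layout of the object array (inferred from the indexing in this
# script): column 1 holds the image volumes, column 2 the one-hot labels.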

# Interpolate along the slice axis so the depth (number of slices) is equal
# across scans: every volume is resampled to 56 slices.
target_depth = 56
scans = []
for dataz in images:
    scans.append(zoom(dataz, (target_depth / dataz.shape[0], 1, 1)))
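
# Sanity check (added for illustration): np.stack below needs uniform shapes,
# so every resampled scan should now have exactly 56 slices.
for s in scans:
    assert s.shape[0] == target_depth, s.shape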

# Histogram equalization, applied slice by slice
work = []
for scan in scans:
    work.append([exposure.equalize_hist(slice_) for slice_ in scan])
print(np.asarray(work).shape)

scanimages = np.stack(work)  # histogram-equalized volumes

# Labels are stored one-hot; argmax converts them to class indices
one_hot_labels = np.stack(arrays[:, 2])
labels = np.argmax(one_hot_labels, axis=1)

tensor_x = torch.Tensor(scanimages).unsqueeze(1)  # add a channel dim: (N, 1, D, H, W)
tensor_y = torch.Tensor(labels).long()

my_dataset = TensorDataset(tensor_x, tensor_y)
train_dataset, test_dataset = random_split(my_dataset, [32, 9])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
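
# Note: random_split (above) requires its split sizes to sum to
# len(my_dataset), so the [32, 9] split assumes exactly 41 samples.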

# Conv output size: (N - F + 2P) / stride + 1
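# Worked size check for the classifier input, assuming the (1, 56, 128, 128)
# volumes produced above (a sketch; re-derive if the input shape changes):
#   conv1 k=5 p=2: 56x128x128 -> pool/2 -> 28x64x64  (32 ch)
#   conv2 k=5 p=2: 28x64x64   -> pool/2 -> 14x32x32  (64 ch)
#   conv3 k=3 p=2: 16x34x34   -> pool/2 -> 8x17x17   (128 ch)
#   conv4 k=3 p=2: 10x19x19   -> pool/2 -> 5x9x9     (256 ch)
#   flatten: 256 * 5 * 9 * 9 = 103680, matching fc2's in_features below.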

class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        # Convolution 1
        self.cnn1 = nn.Conv3d(in_channels=1, out_channels=32, kernel_size=5, stride=1, padding=2)
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool3d(kernel_size=2)
        # Convolution 2
        self.cnn2 = nn.Conv3d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(p=0.1)
        self.maxpool2 = nn.MaxPool3d(kernel_size=2)
        # Convolution 3
        self.cnn3 = nn.Conv3d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=2)
        self.relu3 = nn.ReLU()
        self.dropout3 = nn.Dropout(p=0.1)
        self.maxpool3 = nn.MaxPool3d(kernel_size=2)
        # Convolution 4
        self.cnn4 = nn.Conv3d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=2)
        self.relu4 = nn.ReLU()
        self.maxpool4 = nn.MaxPool3d(kernel_size=2)
        # Dropout for regularization before the classifier
        self.dropout4 = nn.Dropout(p=0.5)
        # Fully connected classifier (in_features derived above)
        self.fc2 = nn.Linear(103680, num_classes)

    def forward(self, x):
        # Convolution 1
        out = self.cnn1(x)
        out = self.relu1(out)
        out = self.maxpool1(out)
        # Convolution 2
        out = self.cnn2(out)
        out = self.relu2(out)
        out = self.dropout2(out)
        out = self.maxpool2(out)
        # Convolution 3
        out = self.cnn3(out)
        out = self.relu3(out)
        out = self.dropout3(out)
        out = self.maxpool3(out)
        # Convolution 4
        out = self.cnn4(out)
        out = self.relu4(out)
        out = self.maxpool4(out)
        # Flatten to (batch, features)
        out = out.view(out.size(0), -1)
        # Dropout
        out = self.dropout4(out)
        # Classifier returns raw logits; CrossEntropyLoss applies log-softmax
        out = self.fc2(out)
        return out

model = CNNModel()
liveloss = PlotLosses()
model.to('cuda:0')

# Per-class weights to counter class imbalance
weights = [0.5, 0.7, 1.0]
class_weights = torch.FloatTensor(weights).cuda()
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Wrap model and optimizer for apex mixed-precision training
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
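
# Note: apex.amp has since been deprecated upstream in favor of the built-in
# torch.cuda.amp (autocast + GradScaler); this script keeps the apex API.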

#
# Train the model
#
for epoch in range(num_epochs):
    logs = {}
    total_correct = 0
    total_loss = 0
    total_images = 0

    model.train()
    for i, (data, target) in enumerate(train_loader):
        images = data.to('cuda:0')
        labels = target.to('cuda:0')
        # Forward propagation
        outputs = model(images)
        # Cross-entropy loss on the raw logits
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        # Backward pass through apex's loss scaler
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        # Updating gradients
        optimizer.step()
        # Total number of labels seen this epoch
        total_images += labels.size(0)
        # Predictions from the max logit
        _, predicted = torch.max(outputs.detach(), 1)
        # Count correct answers
        correct = (predicted == labels).sum().item()
        total_correct += correct
        total_loss += loss.item()
        logs['log loss'] = total_loss / total_images
        logs['Accuracy'] = (total_correct / total_images) * 100
        print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Accuracy: {:.2f}%'
              .format(epoch + 1, num_epochs, i + 1, len(train_loader),
                      total_loss / total_images,
                      (total_correct / total_images) * 100))

    # Evaluate on the held-out split
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        total_val_loss = 0
        for data, target in test_loader:
            images = data.to('cuda:0')
            labels = target.to('cuda:0')
            outputs = model(images)
            # Validation loss on this batch
            val_loss = criterion(outputs, labels)
            _, predicted = torch.max(outputs.detach(), 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            total_val_loss += val_loss.item()

    print('Test Accuracy of the model: {} %'.format(100 * correct / total))
    logs['val_log loss'] = total_val_loss / total
    logs['val_Accuracy'] = (correct / total) * 100
    liveloss.update(logs)
    liveloss.draw()
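
# Minimal single-sample inference sketch (added for illustration, not part of
# the original training run): classify the first held-out test volume.
sample_x, sample_y = test_dataset[0]
with torch.no_grad():
    logits = model(sample_x.unsqueeze(0).to('cuda:0'))  # add batch dim: (1, 1, D, H, W)
    predicted_class = logits.argmax(dim=1).item()
print('predicted class:', predicted_class, '| true class:', sample_y.item())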