-
-
Save Sieyk/d3737b12d38ec41e0792f6eae7a57ffc to your computer and use it in GitHub Desktop.
Source code for pytorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import print_function, division | |
import os | |
import torch | |
import math | |
import pandas as pd | |
from skimage import io, transform, color, img_as_float | |
# from PIL import Image | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from torch.utils.data import Dataset, DataLoader | |
from torchvision import transforms, utils | |
# from torchvision.transforms import functional | |
from torch.autograd import Variable | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import cv2 | |
import torch.optim as optim | |
import random | |
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")
# Number of CPU threads PyTorch may use for intra-op parallelism.
# Assigning a Python variable called OMP_NUM_THREADS does not configure
# anything by itself, so the value is applied explicitly below.
OMP_NUM_THREADS = 11
torch.set_num_threads(OMP_NUM_THREADS)
plt.ion()  # interactive mode
# minerals_frame = pd.read_csv('dataset_processed.csv') | |
# | |
# n = 1 | |
# _img_name = minerals_frame.iloc[n, 0] | |
# _minerals = minerals_frame.iloc[n, 1:].as_matrix() | |
# _minerals = _minerals.astype('float').reshape(-1, 2) | |
# | |
# print('Image name: {}'.format(_img_name)) | |
# print('Minerals shape: {}'.format(minerals.shape)) | |
# print('First 4 Minerals: {}'.format(minerals[:4])) | |
# l1 = [1, 2, 3] | |
# l2 = [4, 5, 6] | |
# | |
# for i, j in zip(l1, l2): | |
# print(i, j) | |
# exit() | |
class Net(nn.Module):
    """Grouped-convolution classifier producing 21 sigmoid scores per image.

    The network is a stack of (mostly depthwise/grouped) convolutions with
    LeakyReLU activations, one 2x2 max-pool, and four fully connected layers.
    ``forward`` returns both the 21 sigmoid outputs and the activation of
    ``conv5_5`` (taken before pooling) so callers can compare intermediate
    features of different inputs.

    Args:
        multiplier (int): Spatial side length of the feature map that reaches
            the fully connected layers; ``fc1`` expects
            ``512 * multiplier * multiplier`` inputs. With the kernel sizes
            and paddings used here a 35x35 input shrinks to 1x1, which
            matches the default of 1.
    """

    def __init__(self, multiplier=1):
        super(Net, self).__init__()

        def simple(in_channels, out_channels, kernel_size, padding=0, groups=1, inplace=True):
            """Build a stride-1 convolution followed by LeakyReLU."""
            # groups defaults to 1: Conv2d rejects groups=0.
            return torch.nn.Sequential(
                torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                kernel_size=kernel_size, stride=1, padding=padding, groups=groups),
                torch.nn.LeakyReLU(negative_slope=1/5.5, inplace=inplace),
            )

        def linear(in_features, out_features, dropout=False):
            """Build Linear + LeakyReLU, optionally followed by Dropout."""
            layers = [
                torch.nn.Linear(in_features=in_features, out_features=out_features),
                torch.nn.LeakyReLU(negative_slope=1/5.5, inplace=True),
            ]
            # Only append Dropout when requested; a ``None`` entry inside
            # nn.Sequential would fail as soon as the block is called.
            if dropout:
                layers.append(torch.nn.Dropout())
            return torch.nn.Sequential(*layers)

        flat_features = 512 * multiplier * multiplier

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=96, kernel_size=5, padding=2, groups=3)
        self.conv2 = simple(in_channels=96, out_channels=128, kernel_size=5, padding=2, groups=16)
        self.conv3_1 = simple(in_channels=128, out_channels=256, kernel_size=5, padding=2, groups=128)
        self.conv3_2 = simple(in_channels=256, out_channels=256, kernel_size=5, padding=2, groups=256)
        self.conv3_3 = simple(in_channels=256, out_channels=256, kernel_size=5, padding=2, groups=256)
        self.conv4 = simple(in_channels=256, out_channels=256, kernel_size=5, padding=2, groups=256)
        self.conv5_1 = simple(in_channels=256, out_channels=256, kernel_size=5, padding=2, groups=256)
        self.conv5_2 = simple(in_channels=256, out_channels=512, kernel_size=5, padding=2, groups=256)
        self.conv5_3 = simple(in_channels=512, out_channels=512, kernel_size=5, padding=2, groups=512)
        self.conv5_4 = simple(in_channels=512, out_channels=512, kernel_size=4, padding=1, groups=512)
        # Not in-place: the output of this block is returned as ``features``.
        self.conv5_5 = simple(in_channels=512, out_channels=512, kernel_size=4, padding=1, groups=512, inplace=False)
        self.conv6_1 = simple(in_channels=512, out_channels=512, kernel_size=4, padding=1, groups=512)
        self.conv6_2 = simple(in_channels=512, out_channels=512, kernel_size=5, padding=0, groups=512)
        self.conv6_3 = simple(in_channels=512, out_channels=512, kernel_size=5, padding=0, groups=512)
        self.conv6_4 = simple(in_channels=512, out_channels=512, kernel_size=4, padding=0, groups=512)
        self.conv6_5 = simple(in_channels=512, out_channels=512, kernel_size=4, padding=0, groups=512)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = linear(flat_features, flat_features, dropout=True)
        self.fc2 = linear(flat_features, flat_features, dropout=True)
        self.fc3 = linear(flat_features, 512, dropout=True)
        self.fc4 = nn.Linear(512, 21)

    def forward(self, x, training=True):
        """Run the classifier.

        Args:
            x (Tensor): Image batch with 3 channels (``conv1`` has
                ``in_channels=3``).
            training (bool): Accepted for caller compatibility; not used.

        Returns:
            tuple: ``(scores, features)`` where ``scores`` has 21 sigmoid
            outputs per sample and ``features`` is the output of ``conv5_5``.
        """
        x = F.leaky_relu(self.conv1(x), negative_slope=1/5.5)
        x = self.conv2(x)
        x = self.conv3_1(x)
        x = self.conv3_2(x)
        x = self.conv3_3(x)
        x = self.conv4(x)
        x = self.conv5_1(x)
        x = self.conv5_2(x)
        x = self.conv5_3(x)
        x = self.conv5_4(x)
        features = self.conv5_5(x)
        # Reuse the activation instead of evaluating conv5_5 a second time.
        x = self.pool(features)
        x = self.conv6_1(x)
        x = self.conv6_2(x)
        x = self.conv6_3(x)
        x = self.conv6_4(x)
        x = self.conv6_5(x)
        # Flatten to (batch, 512 * multiplier * multiplier) for the FC head.
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        x = torch.sigmoid(self.fc4(x))
        return x, features
class GNet(nn.Module):
    """Convolutional generator mapping a 1-channel noise map to a 3-channel image.

    The network is seven unpadded, stride-1 grouped convolutions, each
    followed by LeakyReLU and BatchNorm2d, with a final sigmoid so outputs lie
    in [0, 1]. Five 11x11 kernels and two 4x4 kernels shrink each spatial
    side by 56 pixels in total, so a 91x91 input yields a 35x35 output.
    """

    def __init__(self):
        super(GNet, self).__init__()

        def simple(in_channels, out_channels, kernel_size, padding=0, groups=1):
            """Build a stride-1 convolution + LeakyReLU + BatchNorm2d block."""
            # groups defaults to 1: Conv2d rejects groups=0.
            return torch.nn.Sequential(
                torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                kernel_size=kernel_size, stride=1, padding=padding, groups=groups),
                torch.nn.LeakyReLU(negative_slope=1/5.5, inplace=False),
                torch.nn.BatchNorm2d(out_channels),
            )

        # Attribute names are kept as-is so state dicts of previously saved
        # generators still match.
        self.exp_conv3 = simple(1, 33, 11, groups=1)
        self.exp_conv4 = simple(33, 66, 11, groups=33)
        self.exp_conv5 = simple(66, 132, 11, groups=66)
        self.exp_conv6 = simple(132, 264, 11, groups=132)
        self.exp_conv7 = simple(264, 132, 11, groups=132)
        self.exp_conv8 = simple(132, 66, 4, groups=66)
        self.exp_conv9 = simple(66, 3, 4, groups=3)

    def forward(self, x):
        """Generate images from noise.

        Args:
            x (Tensor): Batch with a single channel (``exp_conv3`` has
                ``in_channels=1``) and spatial sides of at least 57 pixels.

        Returns:
            Tensor: 3-channel batch with values in [0, 1]; each spatial side
            is 56 pixels smaller than the input's.
        """
        x = self.exp_conv3(x)
        x = self.exp_conv4(x)
        x = self.exp_conv5(x)
        x = self.exp_conv6(x)
        x = self.exp_conv7(x)
        x = self.exp_conv8(x)
        x = torch.sigmoid(self.exp_conv9(x))
        return x
class MineralsDataset(Dataset):
    """Dataset of mineral images with a fixed-length target vector per image.

    Each CSV row holds an image name (without the ``.bmp`` extension) in the
    first column, followed by flattened ``(class index, value)`` pairs. The
    pairs are expanded into a dense float32 vector of length 21.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.minerals_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.minerals_frame)

    @staticmethod
    def _encode_minerals(pairs, num_classes=21):
        """Expand flattened (class index, value) pairs into a dense vector.

        Args:
            pairs (array-like): Numeric values laid out as
                ``[idx0, val0, idx1, val1, ...]``.
            num_classes (int): Length of the returned vector.

        Returns:
            numpy.ndarray: float32 vector with ``vector[idx] = val`` for every
            pair and zeros elsewhere. Pairs containing NaN are skipped.
        """
        pairs = np.asarray(pairs, dtype=np.float64).reshape(-1, 2)
        encoded = np.zeros(num_classes, dtype=np.float32)
        for class_index, amount in pairs:
            # A column with missing entries is read as float/NaN by pandas;
            # skip those and cast the index so it is a valid array index.
            if np.isnan(class_index) or np.isnan(amount):
                continue
            encoded[int(class_index)] = amount
        return encoded

    def __getitem__(self, idx):
        """Load image ``idx`` and its target vector.

        Returns:
            dict: ``{'image': ndarray, 'minerals': float32 ndarray}``, passed
            through ``self.transform`` when one was given.
        """
        img_name = os.path.join(self.root_dir,
                                self.minerals_frame.iloc[idx, 0])
        img = io.imread(img_name + '.bmp')
        # ``.values`` replaces ``.as_matrix()``, which pandas 1.0 removed.
        minerals = self._encode_minerals(self.minerals_frame.iloc[idx, 1:].values)
        sample = {'image': img, 'minerals': minerals}
        if self.transform:
            sample = self.transform(sample)
        return sample
class Rescale(object):
    """Resize the image of a sample; the 'minerals' entry passes through.

    Args:
        output_size (tuple or int): Target size. A tuple is used directly as
            (height, width). An int sets the shorter image edge to that
            length and scales the other edge to keep the aspect ratio.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        picture = sample['image']
        targets = sample['minerals']
        height, width = picture.shape[:2]
        size = self.output_size

        if not isinstance(size, int):
            # Explicit (height, width) request.
            out_h, out_w = size
        elif height > width:
            # Portrait: width is the shorter edge.
            out_h, out_w = size * height / width, size
        else:
            # Landscape or square: height is the shorter edge.
            out_h, out_w = size, size * width / height

        resized = transform.resize(picture, (int(out_h), int(out_w)))
        return {'image': resized, 'minerals': targets}
class RandomFlip(object):
    """Flip the image vertically and/or horizontally, each with probability 0.5.

    The 'minerals' entry of the sample is returned unchanged.
    """

    def __call__(self, sample):
        picture = sample['image']
        targets = sample['minerals']
        coin = random.SystemRandom()
        # Vertical flip is decided first, then horizontal; one draw each.
        for flip in (np.flipud, np.fliplr):
            if coin.random() < 0.5:
                # Copy so the result is a regular array rather than a
                # reversed view of the input.
                picture = flip(picture).copy()
        return {'image': picture, 'minerals': targets}
class RandomRot(object):
    """Rotate the image by a random multiple of 90 degrees (0 to 3 turns).

    The 'minerals' entry of the sample is returned unchanged.
    """

    def __call__(self, sample):
        picture = sample['image']
        targets = sample['minerals']
        # random() is in [0, 1), so the truncated product is 0, 1, 2 or 3.
        quarter_turns = int(random.SystemRandom().random() * 4)
        rotated = np.rot90(picture, quarter_turns).copy()
        return {'image': rotated, 'minerals': targets}
class RandomCrop(object):
    """Crop randomly the image in a sample.

    The 'minerals' entry is a per-image target vector rather than pixel
    coordinates, so it is returned unchanged.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        """Return the sample with its image cropped to ``output_size``.

        Raises:
            ValueError: If the requested crop is larger than the image.
        """
        img, minerals = sample['image'], sample['minerals']
        h, w = img.shape[:2]
        new_h, new_w = self.output_size
        if new_h > h or new_w > w:
            raise ValueError('crop size {} exceeds image size {}'.format(
                (new_h, new_w), (h, w)))
        # randint's upper bound is exclusive; +1 keeps the last valid offset
        # reachable and allows a crop equal to the image size.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)
        img = img[top: top + new_h,
                  left: left + new_w]
        return {'image': img, 'minerals': minerals}
class ToTensor(object):
    """Turn the ndarrays of a sample into torch Tensors."""

    def __call__(self, sample):
        picture = sample['image']
        targets = sample['minerals']
        # numpy images are H x W x C while torch expects C x H x W,
        # so the channel axis is moved to the front.
        channels_first = picture.transpose((2, 0, 1))
        image_tensor = torch.from_numpy(channels_first)
        target_tensor = torch.from_numpy(targets).float()
        return {'image': image_tensor, 'minerals': target_tensor}
class Normalize(object):
    """Apply ``(x - 127.5) / 127.5`` to each of the three image channels.

    The 'minerals' entry of the sample is returned unchanged.
    NOTE(review): these constants presumably expect pixel values in
    [0, 255] -- confirm against the image range produced upstream.
    """

    def __call__(self, sample):
        picture = sample['image']
        targets = sample['minerals']
        half_range = (127.5, 127.5, 127.5)
        rescale = transforms.Normalize(half_range, half_range)
        return {'image': rescale(picture), 'minerals': targets}
# Batch size used by the loader below.
batch_size = 32

# Every sample is resized to 35x35, randomly flipped and rotated, and then
# converted to tensors. Normalization is currently disabled.
minerals_dataset = MineralsDataset(
    csv_file='dataset_processed.csv',
    root_dir='Sliced/',
    transform=transforms.Compose([
        Rescale((35, 35)),
        RandomFlip(),
        RandomRot(),
        ToTensor(),
        # Normalize(),
    ]),
)

dataloader = DataLoader(
    minerals_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)
# Helper function to show a batch | |
def show_minerals_batch(sample_batched):
    """Show the images of a batch as a grid, with point overlays if present.

    Args:
        sample_batched (dict): Batch with an 'image' tensor of shape
            (batch, channels, height, width) and a 'minerals' tensor.
            Points are drawn only when 'minerals' is 3-D, i.e.
            (batch, points, 2) holding x/y coordinates. A 2-D target
            tensor has no coordinates to plot and is ignored.
    """
    images_batch, minerals_batch = \
        sample_batched['image'], sample_batched['minerals']
    # make_grid's defaults, spelled out because the overlay offsets below
    # depend on them.
    images_per_row = 8
    grid_padding = 2
    grid = utils.make_grid(images_batch, nrow=images_per_row, padding=grid_padding)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))
    if minerals_batch.dim() == 3:
        im_height, im_width = images_batch.size(2), images_batch.size(3)
        for n in range(len(images_batch)):
            # Tile n sits at (row, col); each tile is preceded by one
            # padding strip in both directions.
            row, col = divmod(n, images_per_row)
            x_offset = grid_padding + col * (im_width + grid_padding)
            y_offset = grid_padding + row * (im_height + grid_padding)
            plt.scatter(minerals_batch[n, :, 0].numpy() + x_offset,
                        minerals_batch[n, :, 1].numpy() + y_offset,
                        s=10, marker='.', c='r')
    plt.title('Batch from dataloader')
def imshow(img):
    """Display a C x H x W tensor after applying ``img / 2 + 0.5``."""
    restored = img / 2 + 0.5  # unnormalize
    # matplotlib wants H x W x C, so move the channel axis to the end.
    height_width_channels = np.transpose(restored.numpy(), (1, 2, 0))
    plt.imshow(height_width_channels)
# if __name__ == "__main__": | |
# net = torch.load('trained.pt').eval().cuda() | |
# gen = torch.load('gen.pt').cuda() | |
# # img = io.imread("sample.bmp") | |
# samples = [ | |
# "E48D269_0001_296.bmp", | |
# "E48D269_0001_296_flip_v.bmp", | |
# "E48D269_0001_296_flip_h.bmp", | |
# "E48D269_0001_296_r_90.bmp", | |
# "E48D269_0001_296_r_180.bmp", | |
# "E48D269_0001_296_r_270.bmp", | |
# "E48D269_0065_566.bmp", | |
# ] | |
# for sample in samples: | |
# # img = color.rgb2ycbcr(io.imread(sample)) | |
# img = io.imread(sample) | |
# # img = color.rgb2ycbcr(io.imread("E48D269_0001_296.bmp")) | |
# img = transform.resize(img, (35, 35)) | |
# # img = cv2.resize(img, dsize=(35, 35), interpolation=cv2.INTER_CUBIC) | |
# img = img.transpose((2, 0, 1)) | |
# img = torch.from_numpy(img) | |
# # norm = transforms.Normalize((127.5, 127.5, 127.5), (127.5, 127.5, 127.5)) | |
# # img = norm(img) | |
# img = Variable(img).cuda() | |
# img = img.unsqueeze(0) | |
# outputs, _ = net(img.float()) | |
# print(sample) | |
# print(outputs) | |
# # print(outputs.data.max(1)[1]) | |
# # for pred in outputs.view(3, 1, 20): | |
# # print(pred.data.max(1)[1]) | |
# # output = gen(Variable(torch.from_numpy(np.array([4, 2, 11])).float()).cuda()).view(3, 35, 35) | |
# # print(output.data) | |
# # print(output.data) | |
# # utils.save_image(color.ycbcr2rgb(output.data.cpu().numpy().transpose((1, 2, 0))), 'fake_image.bmp', nrow=35) | |
# # img_out = color.ycbcr2rgb(img_as_float(output.data.cpu().numpy().transpose((1, 2, 0)))) | |
# # print(img_out) | |
# # io.imsave('fake_image.bmp', img_out) | |
# | |
# exit() | |
if __name__ == '__main__':
    # Training script: alternates between updating the classifier ``net`` on
    # real and generated images and updating the generator ``G``.
    #
    # loading: resume from the saved model files instead of starting fresh.
    # switch:  True trains net + G. The False branch only creates ``net_2``
    #          and leaves ``net``/``G`` as None, so the optimizer setup below
    #          cannot run in that mode.
    loading = True
    net = None
    net_2 = None
    G = None
    switch = True
    if not loading:
        if switch:
            net = Net().cuda()
            G = GNet().cuda()
        else:
            net_2 = Net()
            net_2 = net_2.cuda()
    else:
        if switch:
            # NOTE: torch.load unpickles whole module objects here; only load
            # checkpoint files from a trusted source.
            net = torch.load('trained.pt').cuda()
            G = torch.load('gen.pt').cuda()
        else:
            net_2 = torch.load('trained_2.pt')
            net_2 = net_2.cuda()

    criterion = None
    criterion2 = None
    if switch:
        criterion = nn.BCELoss().cuda()
        criterion2 = nn.MSELoss().cuda()
    else:
        criterion = nn.MSELoss().cuda()

    optimizer = optim.Adam(net.parameters(), lr=3e-4)
    g_optimizer = optim.Adam(G.parameters(), lr=3e-4)

    # Best running averages seen so far, one slot per print interval of an epoch.
    best_batch = [1000.0] * 1000
    fake_best_batch = [1000.0] * 1000
    tracker = 0

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.fastest = True

    batch_print = 20  # print/save interval in mini-batches
    steps = 2         # generator-only updates per mini-batch

    # Target for the generator updates: drive output 20 of ``net`` towards 0.
    b = torch.zeros((batch_size, 1), dtype=torch.float, device='cuda')
    # index_select expects a 1-D index tensor.
    index = torch.tensor([20], dtype=torch.long, device='cuda')

    # Target for generated images in the classifier update: only output 20 set.
    fake_labels = torch.zeros((batch_size, 21), dtype=torch.float, device='cuda')
    fake_labels[:, 20] = 1

    # Noise reused for every full-size batch.
    # NOTE(review): this is never resampled, so full batches always see the
    # same generator input -- confirm this is intended.
    random_noise = torch.randn((batch_size, 1, 91, 91), dtype=torch.float, device='cuda')

    def noise_for(n):
        """Return the fixed noise for a full batch, fresh noise otherwise."""
        if n == batch_size:
            return random_noise
        return torch.randn((n, 1, 91, 91), dtype=torch.float, device='cuda')

    # Generated samples are written here; create the directory up front.
    if not os.path.isdir('fakes'):
        os.makedirs('fakes')

    for epoch in range(1000):
        running_loss = 0.0
        fake_running_loss = 0.0
        for i, data in enumerate(dataloader, 0):
            # Move the batch to the GPU as float tensors.
            inputs = data['image'].to(device='cuda', dtype=torch.float)
            labels = data['minerals'].to(device='cuda', dtype=torch.float)

            # The last batch of an epoch may be smaller; slice the fixed
            # targets to the actual size instead of assuming a length.
            current_batch = labels.size(0)
            batch_fake_labels = fake_labels[:current_batch]
            batch_b = b[:current_batch]

            # ============= Train the classifier =============#
            outputs, real_features = net(inputs, training=True)
            real_loss = criterion(outputs, labels)
            fake_images = G(noise_for(current_batch))
            del outputs
            outputs, fake_features = net(fake_images, training=True)
            # NOTE(review): this squares the *sum* of the differences, not
            # the individual differences -- confirm that is the intended
            # feature loss.
            feature_loss = (torch.sum(fake_features - real_features) ** 2) / fake_features.nelement()
            fake_loss = criterion(outputs, batch_fake_labels)
            total_loss = (real_loss * 2) + (fake_loss / 10)

            # Take the generator's feature-loss gradient before the
            # classifier weights change: running backward through ``net``
            # after optimizer.step() has modified its weights in place is
            # rejected by current PyTorch (and used the wrong weights on old
            # versions).
            g_params = [p for p in G.parameters() if p.requires_grad]
            g_grads = torch.autograd.grad(feature_loss, g_params,
                                          retain_graph=True, allow_unused=True)

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            # Update the generator using the feature-loss gradient only;
            # whatever total_loss.backward() left in G's .grad is overwritten.
            g_optimizer.zero_grad()
            for param, grad in zip(g_params, g_grads):
                param.grad = grad
            g_optimizer.step()
            del outputs

            # ============= Train the generator =============#
            for _ in range(steps):
                fake_images = G(noise_for(current_batch))
                outputs, _ = net(fake_images, training=True)
                g_loss = criterion(torch.index_select(outputs, 1, index), batch_b)
                optimizer.zero_grad()
                g_optimizer.zero_grad()
                g_loss.backward()
                g_optimizer.step()
                # .item() reads a 0-dim loss; ``.data[0]`` fails on PyTorch >= 0.5.
                fake_running_loss += g_loss.item()

            running_loss += total_loss.item()

            if i % batch_print == batch_print - 1:  # print every n mini-batches
                if tracker == len(best_batch):
                    # More print intervals in this epoch than pre-allocated slots.
                    best_batch.append(1000.0)
                    fake_best_batch.append(1000.0)
                avg_loss = running_loss / batch_print
                avg_adv_loss = (fake_running_loss / float(steps)) / batch_print
                print('[%d, %5d] loss: %.15f %s, adv_loss: %.15f %s' %
                      (epoch + 1, i + 1, avg_loss,
                       '{best!}' if avg_loss < best_batch[tracker] else '',
                       avg_adv_loss,
                       '{best!}' if avg_adv_loss < fake_best_batch[tracker] else ''
                       ))
                if avg_loss < best_batch[tracker]:
                    best_batch[tracker] = avg_loss
                if avg_adv_loss < fake_best_batch[tracker]:
                    fake_best_batch[tracker] = avg_adv_loss

                # Save the most recent generated batch. The tensor must be
                # copied to the CPU before converting to numpy.
                for img_num, image in enumerate(fake_images.detach().cpu().numpy(), 1):
                    image = image.transpose(1, 2, 0)  # C x H x W -> H x W x C
                    io.imsave('fakes/fake_image' + str(img_num) + '.bmp', image)

                running_loss = 0.0
                fake_running_loss = 0.0
                tracker += 1

        # End of epoch: restart interval tracking and save the models.
        tracker = 0
        if switch:
            torch.save(net, 'trained.pt')
            torch.save(G, 'gen.pt')
        else:
            torch.save(net_2, 'trained_2.pt')
        torch.save(net, 'BACKUP.pt')
        torch.save(G, 'BACKUP_gen.pt')
        print('Saved model!')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment