# Source code for PyTorch: minerals classifier (Net), GAN-style generator (GNet) and training loop.
from __future__ import print_function, division
import os
import torch
import math
import pandas as pd
from skimage import io, transform, color, img_as_float
# from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
# from torchvision.transforms import functional
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import cv2
import torch.optim as optim
import random
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")
torch.set_num_threads(11)  # limit CPU intra-op parallelism (a bare OMP_NUM_THREADS variable assignment has no effect)
plt.ion() # interactive mode
# minerals_frame = pd.read_csv('dataset_processed.csv')
#
# n = 1
# _img_name = minerals_frame.iloc[n, 0]
# _minerals = minerals_frame.iloc[n, 1:].as_matrix()
# _minerals = _minerals.astype('float').reshape(-1, 2)
#
# print('Image name: {}'.format(_img_name))
# print('Minerals shape: {}'.format(_minerals.shape))
# print('First 4 Minerals: {}'.format(_minerals[:4]))
# l1 = [1, 2, 3]
# l2 = [4, 5, 6]
#
# for i, j in zip(l1, l2):
#     print(i, j)
# exit()
class Net(nn.Module):
    def __init__(self, multiplier=1):  # multiplier = spatial size of the feature map after the conv layers
        super(Net, self).__init__()

        def simple(in_channels, out_channels, kernel_size, padding=0, groups=1, inplace=True):
            """Grouped convolution followed by a LeakyReLU."""
            return torch.nn.Sequential(
                torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                kernel_size=kernel_size, stride=1, padding=padding, groups=groups),
                torch.nn.LeakyReLU(negative_slope=1/5.5, inplace=inplace),
                # torch.nn.BatchNorm2d(out_channels),
            )

        def linear(in_features, out_features, dropout=False):
            """Fully connected layer followed by a LeakyReLU and optional dropout."""
            layers = [
                torch.nn.Linear(in_features=in_features, out_features=out_features),
                torch.nn.LeakyReLU(negative_slope=1/5.5, inplace=True),
            ]
            if dropout:
                layers.append(torch.nn.Dropout())
            return torch.nn.Sequential(*layers)

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=96, kernel_size=5, padding=2, groups=3)
        self.conv2 = simple(in_channels=96, out_channels=128, kernel_size=5, padding=2, groups=16)
        self.conv3_1 = simple(in_channels=128, out_channels=256, kernel_size=5, padding=2, groups=128)
        self.conv3_2 = simple(in_channels=256, out_channels=256, kernel_size=5, padding=2, groups=256)
        self.conv3_3 = simple(in_channels=256, out_channels=256, kernel_size=5, padding=2, groups=256)
        self.conv4 = simple(in_channels=256, out_channels=256, kernel_size=5, padding=2, groups=256)
        self.conv5_1 = simple(in_channels=256, out_channels=256, kernel_size=5, padding=2, groups=256)
        self.conv5_2 = simple(in_channels=256, out_channels=512, kernel_size=5, padding=2, groups=256)
        self.conv5_3 = simple(in_channels=512, out_channels=512, kernel_size=5, padding=2, groups=512)
        self.conv5_4 = simple(in_channels=512, out_channels=512, kernel_size=4, padding=1, groups=512)
        self.conv5_5 = simple(in_channels=512, out_channels=512, kernel_size=4, padding=1, groups=512, inplace=False)
        self.conv6_1 = simple(in_channels=512, out_channels=512, kernel_size=4, padding=1, groups=512)
        self.conv6_2 = simple(in_channels=512, out_channels=512, kernel_size=5, padding=0, groups=512)
        self.conv6_3 = simple(in_channels=512, out_channels=512, kernel_size=5, padding=0, groups=512)
        self.conv6_4 = simple(in_channels=512, out_channels=512, kernel_size=4, padding=0, groups=512)
        self.conv6_5 = simple(in_channels=512, out_channels=512, kernel_size=4, padding=0, groups=512)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = linear(512*multiplier*multiplier, 512*multiplier*multiplier, dropout=True)
        self.fc2 = linear(512*multiplier*multiplier, 512*multiplier*multiplier, dropout=True)
        self.fc3 = linear(512*multiplier*multiplier, 512, dropout=True)
        self.fc4 = nn.Linear(512, 21)

    def forward(self, x, training=True):
        x = F.leaky_relu(self.conv1(x), negative_slope=1/5.5)
        x = self.conv2(x)
        x = self.conv3_1(x)
        x = self.conv3_2(x)
        x = self.conv3_3(x)
        x = self.conv4(x)
        x = self.conv5_1(x)
        x = self.conv5_2(x)
        x = self.conv5_3(x)
        x = self.conv5_4(x)
        features = self.conv5_5(x)   # intermediate feature map, also returned for the feature-matching loss
        x = self.pool(features)      # reuse it instead of running conv5_5 a second time
        x = self.conv6_1(x)
        x = self.conv6_2(x)
        x = self.conv6_3(x)
        x = self.conv6_4(x)
        x = self.conv6_5(x)
        # x = F.leaky_relu(x + base_filter, negative_slope=1/5.5)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        x = torch.sigmoid(self.fc4(x))   # 21-way multi-label output (index 20 is the "fake" unit)
        return x, features
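

# Quick shape check (a minimal sketch, assuming the 35x35 RGB inputs produced by Rescale((35, 35))
# below and the default multiplier=1); with those sizes the conv stack reduces the map to 1x1,
# so the flattened size entering fc1 is 512:
#
#   net = Net(multiplier=1)
#   dummy = torch.randn(2, 3, 35, 35)
#   scores, feats = net(dummy)
#   print(scores.size())   # torch.Size([2, 21])
#   print(feats.size())    # torch.Size([2, 512, 33, 33])
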
class GNet(nn.Module):
    def __init__(self):
        super(GNet, self).__init__()

        def linear(in_features, out_features):
            return torch.nn.Sequential(
                torch.nn.Linear(in_features=in_features, out_features=out_features),
                torch.nn.LeakyReLU(negative_slope=1/5.5, inplace=False),
                torch.nn.Dropout(0.1),
            )

        def simple(in_channels, out_channels, kernel_size, padding=0, groups=1):
            return torch.nn.Sequential(
                torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                kernel_size=kernel_size, stride=1, padding=padding, groups=groups),
                torch.nn.LeakyReLU(negative_slope=1/5.5, inplace=False),
                torch.nn.BatchNorm2d(out_channels),
            )

        # Earlier fully-connected / conv experiments, kept for reference:
        # self.fc1 = linear(128, 256)
        # self.fc2 = linear(256, 512)
        # self.fc3 = linear(512, 1024)
        # self.fc4 = linear(1024, 2048)
        # self.fc5 = linear(2048, 3675)
        # self.pool = nn.MaxPool2d(2)
        # self.conv_1 = simple(3, 105, 5, 0, 3)        # 31
        # self.conv_2 = simple(105, 315, 5, 0, 105)    # 27
        # self.conv_3 = simple(315, 630, 6, 0, 315)    # 22, max pool
        # self.conv_4 = simple(630, 945, 6, 0, 315)    # 6
        # self.conv_5 = nn.Conv2d(945, 1260, 6, padding=0, groups=105)
        # self.fc6 = linear(1260, 1890)
        # self.fc7 = linear(1890, 2835)
        # self.fc8 = linear(2835, 3675)
        # self.exp_conv1 = simple(512, 512, 21, groups=512)
        # self.exp_conv2 = simple(512, 256, 21, groups=256)

        # Active generator: valid (unpadded) grouped convs map 1x91x91 noise down to 3x35x35.
        self.exp_conv3 = simple(1, 33, 11, groups=1)
        self.exp_conv4 = simple(33, 66, 11, groups=33)
        self.exp_conv5 = simple(66, 132, 11, groups=66)
        self.exp_conv6 = simple(132, 264, 11, groups=132)
        self.exp_conv7 = simple(264, 132, 11, groups=132)
        self.exp_conv8 = simple(132, 66, 4, groups=66)
        self.exp_conv9 = simple(66, 3, 4, groups=3)
        # self.fc1 = linear(3675, 3675)
        # self.fc2 = linear(3675, 3675)
        # self.fc3 = linear(3675, 3675)
        # self.conv_layer = nn.Sequential(
        #     nn.Conv2d(3, 12, 3, groups=3, padding=1),
        #     nn.LeakyReLU(1 / 5.5),
        #     nn.BatchNorm2d(12),
        #     # nn.MaxPool2d(2),
        #     nn.Conv2d(12, 24, 3, groups=12, padding=1),
        #     nn.LeakyReLU(1 / 5.5),
        #     nn.BatchNorm2d(24),
        #     # nn.MaxPool2d(2),
        #     nn.Conv2d(24, 32, 3, padding=1, groups=8),
        #     nn.LeakyReLU(1 / 5.5),
        #     nn.BatchNorm2d(32),
        #     # nn.MaxPool2d(2),
        #     nn.Conv2d(32, 64, 3, padding=1, groups=32),
        #     nn.LeakyReLU(1 / 5.5),
        #     nn.BatchNorm2d(64),
        #     nn.Conv2d(64, 128, 3, padding=1, groups=64),
        #     nn.LeakyReLU(1 / 5.5),
        #     nn.BatchNorm2d(128),
        #     nn.Conv2d(128, 256, 3, padding=1, groups=128),
        #     nn.LeakyReLU(1 / 5.5),
        #     nn.BatchNorm2d(256),
        #     nn.Conv2d(256, 512, 3, padding=1, groups=256),
        #     nn.LeakyReLU(1 / 5.5),
        #     nn.BatchNorm2d(512),
        #     nn.Conv2d(512, 1024, 3, padding=1, groups=512),
        #     nn.LeakyReLU(1 / 5.5),
        #     nn.BatchNorm2d(1024),
        #     # nn.MaxPool2d(2),
        #     nn.Conv2d(1024, 1024, 4, padding=1, groups=1024),
        #     nn.LeakyReLU(1 / 5.5),
        #     nn.BatchNorm2d(1024),
        #     nn.MaxPool2d(2),  # 17x17
        #     nn.Conv2d(1024, 1024, 3, padding=0, groups=1024),
        #     nn.LeakyReLU(1 / 5.5),
        #     nn.BatchNorm2d(1024),
        #     nn.Conv2d(1024, 1024, 3, padding=0, groups=1024),
        #     nn.LeakyReLU(1 / 5.5),
        #     nn.BatchNorm2d(1024),
        #     nn.Conv2d(1024, 1024, 3, padding=0, groups=1024),
        #     nn.LeakyReLU(1 / 5.5),
        #     nn.BatchNorm2d(1024),
        #     nn.Conv2d(1024, 1024, 3, padding=0, groups=1024),
        #     nn.LeakyReLU(1 / 5.5),
        #     nn.BatchNorm2d(1024),
        #     nn.Conv2d(1024, 1024, 4, padding=1, groups=1024),
        #     nn.LeakyReLU(1 / 5.5),
        #     nn.BatchNorm2d(1024),
        #     nn.Conv2d(1024, 3136, 3, padding=0, groups=64),
        #     nn.LeakyReLU(1 / 5.5),
        #     nn.BatchNorm2d(3136),
        #     nn.MaxPool2d(2),
        #     nn.Conv2d(3136, 3675, 3, padding=0, groups=49),
        #     nn.Tanh(),
        # )
    def forward(self, x):
        # Earlier fully-connected / reshaping variants, kept for reference:
        # x = self.input_layer(x)
        # x = self.fc1(x)
        # x = self.fc2(x)
        # x = self.fc3(x)
        # x = self.fc4(x)
        # x = self.fc5(x)
        # x = x.view(x.size(0), 3, 35, 35) if x.size(0) != 3675 else x.view(3, 35, 35)
        # x = self.conv_1(x)
        # x = self.conv_2(x)
        # x = self.conv_3(x)
        # x = self.pool(x)
        # x = self.conv_4(x)
        # x = self.conv_5(x)
        # x = x.view(x.size(0), -1)
        # x = self.fc6(x)
        # x = self.fc7(x)
        # x = F.tanh(self.fc8(x))
        # x = x.view(x.size(0), 3, 35, 35) if x.size(0) != 3675 else x.view(3, 35, 35)
        # x = self.conv_1(x)
        # x = self.conv_2(x)
        # x = self.pool(x)
        # x = self.conv_3(x)
        # x = self.conv_4(x)
        # x = self.conv_5(x)
        # # x = x.view(x.size(0), 3, 35, 35) if x.size(0) != 3675 else x.view(3, 35, 35)
        # x = x.view(x.size(0), -1)
        # x = self.fc1(x)
        # x = self.fc2(x)
        # x = self.fc3(x)
        # x = self.fc4(x)
        # x = F.tanh(self.fc5(x))
        # x = x.view(x.size(0), 3, 35, 35) if x.size(0) != 3675 else x.view(3, 35, 35)
        # x = self.conv_layer(x)
        # x = x.view(x.size(0), -1)
        # x = self.exp_conv1(x)
        # x = self.exp_conv2(x)

        # Active path: noise (N, 1, 91, 91) -> image-like output (N, 3, 35, 35) in [0, 1].
        x = self.exp_conv3(x)
        x = self.exp_conv4(x)
        x = self.exp_conv5(x)
        x = self.exp_conv6(x)
        x = self.exp_conv7(x)
        x = self.exp_conv8(x)
        x = torch.sigmoid(self.exp_conv9(x))
        # x = x.view(x.size(0), 3675)
        # x = self.fc1(x)
        # x = self.fc2(x)
        # x = F.tanh(self.fc3(x))
        # print(x.size())
        # x = x.view(x.size(0), 3, 35, 35) if x.size(0) != 3675 else x.view(3, 35, 35)
        return x
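

# Shape sketch for the active generator path (a minimal check, assuming the 1x91x91 noise used in
# the training loop below): five 11x11 valid convs take 91 -> 81 -> 71 -> 61 -> 51 -> 41, then two
# 4x4 valid convs give 38 -> 35, so the output is (N, 3, 35, 35), matching the rescaled images.
#
#   G = GNet()
#   z = torch.randn(2, 1, 91, 91)
#   print(G(z).size())   # torch.Size([2, 3, 35, 35])
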
class MineralsDataset(Dataset):
    """Minerals dataset: image files plus per-image (mineral index, value) pairs from a CSV."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.minerals_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.minerals_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir,
                                self.minerals_frame.iloc[idx, 0])
        img = io.imread(img_name + '.bmp')
        # image = color.rgb2ycbcr(image)
        minerals = self.minerals_frame.iloc[idx, 1:].as_matrix()
        minerals = minerals.reshape(-1, 2)
        # Densify the (index, value) pairs into a fixed-length 21-way target vector.
        replacement = np.zeros(21, dtype=np.float32)
        for mineral in minerals:
            replacement[int(mineral[0])] = mineral[1]
        # for mineral in minerals:
        #     mineral[0] = (mineral[0] + 1) / 21.0
        minerals = replacement
        sample = {'image': img, 'minerals': minerals}
        if self.transform:
            sample = self.transform(sample)
        return sample
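

# Sketch of the label densification above, using a hypothetical CSV row for illustration only:
# a row like  img_001, 4, 0.25, 11, 0.75  yields pairs [[4, 0.25], [11, 0.75]], which are scattered
# into a length-21 vector that is zero everywhere except indices 4 and 11.
#
#   pairs = np.array([[4, 0.25], [11, 0.75]])
#   target = np.zeros(21, dtype=np.float32)
#   for idx_, val in pairs:
#       target[int(idx_)] = val
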
class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        img, minerals = sample['image'], sample['minerals']
        h, w = img.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size
        new_h, new_w = int(new_h), int(new_w)
        img = transform.resize(img, (new_h, new_w))
        # The mineral targets are per-image values, not pixel coordinates, so resizing leaves them
        # unchanged; the landmark-style rescaling below is left over from the tutorial this is based on.
        # minerals = minerals * [new_w / w, new_h / h]
        return {'image': img, 'minerals': minerals}
class RandomFlip(object):
    """Randomly flip the image vertically and/or horizontally."""

    def __call__(self, sample):
        img, minerals = sample['image'], sample['minerals']
        rng = random.SystemRandom()
        if rng.random() < 0.5:
            img = np.flipud(img).copy()
        if rng.random() < 0.5:
            img = np.fliplr(img).copy()
        return {'image': img, 'minerals': minerals}


class RandomRot(object):
    """Randomly rotate the image by a multiple of 90 degrees."""

    def __call__(self, sample):
        img, minerals = sample['image'], sample['minerals']
        rng = random.SystemRandom()
        img = np.rot90(img, int(rng.random() * 4)).copy()
        return {'image': img, 'minerals': minerals}
class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        img, minerals = sample['image'], sample['minerals']
        h, w = img.shape[:2]
        new_h, new_w = self.output_size
        top = np.random.randint(0, h - new_h)
        left = np.random.randint(0, w - new_w)
        img = img[top: top + new_h,
                  left: left + new_w]
        # Note: this offset only makes sense for coordinate-style targets; RandomCrop is not part of
        # the transform pipeline below, where minerals is the dense 21-way vector.
        minerals = minerals - [left, top]
        return {'image': img, 'minerals': minerals}
class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        img, minerals = sample['image'], sample['minerals']
        # swap color axis because
        # numpy image: H x W x C
        # torch image: C x H x W
        img = img.transpose((2, 0, 1))
        return {'image': torch.from_numpy(img),
                'minerals': torch.from_numpy(minerals).float()}


class Normalize(object):
    """Map 0-255 image tensors to [-1, 1] (currently disabled in the transform pipeline)."""

    def __call__(self, sample):
        image_tensor, minerals = sample['image'], sample['minerals']
        norm = transforms.Normalize((127.5, 127.5, 127.5), (127.5, 127.5, 127.5))
        image_tensor = norm(image_tensor)
        # transforms.Normalize(mean=[0.485, 0.456, 0.406],
        #                      std=[0.229, 0.224, 0.225])
        return {'image': image_tensor, 'minerals': minerals}
minerals_dataset = MineralsDataset(csv_file='dataset_processed.csv',
                                   root_dir='Sliced/',
                                   transform=transforms.Compose([
                                       Rescale((35, 35)),
                                       RandomFlip(),
                                       RandomRot(),
                                       ToTensor(),
                                       # Normalize(),
                                   ]))
batch_size = 32
dataloader = DataLoader(minerals_dataset, batch_size=batch_size,
                        shuffle=True, num_workers=2,
                        pin_memory=True)
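
# Smoke test for the data pipeline (a minimal sketch; run manually if needed). Assumes the CSV and
# image directory above exist:
#
#   batch = next(iter(dataloader))
#   print(batch['image'].size())      # torch.Size([32, 3, 35, 35]) for a full batch
#   print(batch['minerals'].size())   # torch.Size([32, 21])
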
# Helper function to show a batch
def show_minerals_batch(sample_batched):
    """Show image with minerals for a batch of samples.

    Note: written for coordinate-style targets; not used with the dense 21-way mineral vector.
    """
    images_batch, minerals_batch = \
        sample_batched['image'], sample_batched['minerals']
    method_batch_size = len(images_batch)
    im_size = images_batch.size(2)
    grid = utils.make_grid(images_batch)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))
    for n in range(method_batch_size):
        plt.scatter(minerals_batch[n, :, 0].numpy() + n * im_size,
                    minerals_batch[n, :, 1].numpy(),
                    s=10, marker='.', c='r')
    plt.title('Batch from dataloader')


def imshow(img):
    img = img / 2 + 0.5  # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
# Earlier inference / debugging block, kept for reference:
# if __name__ == "__main__":
#     net = torch.load('trained.pt').eval().cuda()
#     gen = torch.load('gen.pt').cuda()
#     # img = io.imread("sample.bmp")
#     samples = [
#         "E48D269_0001_296.bmp",
#         "E48D269_0001_296_flip_v.bmp",
#         "E48D269_0001_296_flip_h.bmp",
#         "E48D269_0001_296_r_90.bmp",
#         "E48D269_0001_296_r_180.bmp",
#         "E48D269_0001_296_r_270.bmp",
#         "E48D269_0065_566.bmp",
#     ]
#     for sample in samples:
#         # img = color.rgb2ycbcr(io.imread(sample))
#         img = io.imread(sample)
#         # img = color.rgb2ycbcr(io.imread("E48D269_0001_296.bmp"))
#         img = transform.resize(img, (35, 35))
#         # img = cv2.resize(img, dsize=(35, 35), interpolation=cv2.INTER_CUBIC)
#         img = img.transpose((2, 0, 1))
#         img = torch.from_numpy(img)
#         # norm = transforms.Normalize((127.5, 127.5, 127.5), (127.5, 127.5, 127.5))
#         # img = norm(img)
#         img = Variable(img).cuda()
#         img = img.unsqueeze(0)
#         outputs, _ = net(img.float())
#         print(sample)
#         print(outputs)
#         # print(outputs.data.max(1)[1])
#         # for pred in outputs.view(3, 1, 20):
#         #     print(pred.data.max(1)[1])
#         # output = gen(Variable(torch.from_numpy(np.array([4, 2, 11])).float()).cuda()).view(3, 35, 35)
#         # print(output.data)
#         # utils.save_image(color.ycbcr2rgb(output.data.cpu().numpy().transpose((1, 2, 0))), 'fake_image.bmp', nrow=35)
#         # img_out = color.ycbcr2rgb(img_as_float(output.data.cpu().numpy().transpose((1, 2, 0))))
#         # print(img_out)
#         # io.imsave('fake_image.bmp', img_out)
#
#     exit()
if __name__ == '__main__':
    loading = True     # load previously saved models instead of starting from scratch
    net = None
    net_2 = None
    G = None
    switch = True      # selects which model set is created/loaded and saved (net + G vs. net_2)
    if not loading:
        if switch:
            net = Net().cuda()
            G = GNet().cuda()
        else:
            net_2 = Net()
            net_2 = net_2.cuda()
    else:
        if switch:
            net = torch.load('trained.pt').cuda()
            G = torch.load('gen.pt').cuda()
        else:
            net_2 = torch.load('trained_2.pt')
            net_2 = net_2.cuda()

    # criterion = nn.CrossEntropyLoss()
    criterion = None
    criterion2 = None
    if switch:
        # criterion = nn.NLLLoss().cuda()
        # criterion = nn.MSELoss().cuda()
        criterion = nn.BCELoss().cuda()
        criterion2 = nn.MSELoss().cuda()
    else:
        criterion = nn.MSELoss().cuda()
    # optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    # optimizer = optim.Adadelta(net.parameters(), lr=0.001, momentum=0.9)
    optimizer = optim.Adam(net.parameters(), lr=3e-4)
    g_optimizer = optim.Adam(G.parameters(), lr=3e-4)

    # Track the best average loss seen at each logging point.
    best_batch = []
    fake_best_batch = []
    for _ in range(1000):
        best_batch.append(1000.0)
        fake_best_batch.append(1000.0)
    tracker = 0

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.fastest = True
    batch_print = 20

    # Adversarial targets: output 20 is treated as the "fake" unit. The classifier is trained to
    # set it to 1 for generated images, while the generator is trained to drive it towards 0.
    b = torch.zeros((batch_size, 1), dtype=torch.float, device='cuda')
    temp_b = None
    index = torch.tensor(20, dtype=torch.long, device='cuda')
    temp_batch_size = 0
    small_batch = False
    fake_labels = torch.zeros(batch_size, 21)
    temp_fake_labels = None
    fake_labels[:, 20] = 1
    fake_labels = fake_labels.to(dtype=torch.float, device='cuda')
    temp_fake_labels = torch.zeros(17, 21)   # 17 = size of the final, partial batch
    temp_fake_labels[:, 20] = 1
    temp_fake_labels = temp_fake_labels.to(dtype=torch.float, device='cuda')
    temp_b = torch.zeros((17, 1), dtype=torch.float, device='cuda')
    random_noise = torch.randn((batch_size, 1, 91, 91), dtype=torch.float, device='cuda')
    for epoch in range(1000):
        running_loss = 0.0
        fake_running_loss = 0.0
        loss = 0.0
        loss2 = 0.0
        # random_noise = Variable(torch.randn(batch_size, 1, 91, 91)).cuda()
        ticker = 0
        for i, data in enumerate(dataloader, 0):
            # get the inputs and move them to the GPU
            inputs, labels = data['image'], data['minerals']
            # inputs = inputs.pin_memory()
            # with torch.autograd.profiler.profile(use_cuda=True) as prof:
            labels = labels.to(dtype=torch.float, device='cuda')
            inputs = inputs.to(dtype=torch.float, device='cuda')
            if labels.size(0) < batch_size:
                temp_batch_size = labels.size(0)
                small_batch = True

            # ============= Train the classifier =============#
            outputs, real_features = net(inputs, training=True)
            real_loss = criterion(outputs, labels)
            fake_images = G(random_noise if not small_batch
                            else torch.randn((temp_batch_size, 1, 91, 91), dtype=torch.float, device='cuda'))
            del outputs
            outputs, fake_features = net(fake_images, training=True)
            # Feature-matching term between the classifier's real and fake feature maps.
            feature_loss = (torch.sum(fake_features - real_features) ** 2) / fake_features.data.nelement()
            fake_loss = criterion(outputs, fake_labels if not small_batch else temp_fake_labels)
            total_loss = (real_loss * 2) + (fake_loss / 10)

            # zero the parameter gradients, then update the classifier
            optimizer.zero_grad()
            total_loss.backward(retain_graph=True)
            optimizer.step()
            # net.eval()

            # update the generator on the feature-matching loss
            g_optimizer.zero_grad()
            feature_loss.backward()
            g_optimizer.step()
            del outputs
            # ============= Train the generator =============#
            steps = 2
            for _ in range(steps):
                fake_images = G(random_noise if not small_batch
                                else torch.randn((temp_batch_size, 1, 91, 91), dtype=torch.float, device='cuda'))
                outputs, _ = net(fake_images, training=True)
                img_num = 1
                loss_tracker = 0.0
                # Push the classifier's "fake" unit (index 20) towards 0 for generated images.
                g_loss = criterion(torch.index_select(outputs, 1, index), b if not small_batch else temp_b)
                optimizer.zero_grad()
                g_optimizer.zero_grad()
                g_loss.backward()
                g_optimizer.step()
                fake_running_loss += g_loss.item()

            running_loss += total_loss.item()
            if small_batch:
                small_batch = False
            if i % batch_print == batch_print - 1:    # print every n mini-batches
                print('[%d, %5d] loss: %.15f %s, adv_loss: %.15f %s' %
                      (epoch + 1, i + 1, running_loss / batch_print,
                       '{best!}' if running_loss / batch_print < best_batch[tracker] else '',
                       (fake_running_loss / float(steps)) / batch_print,
                       '{best!}' if (fake_running_loss / float(steps)) / batch_print < fake_best_batch[tracker] else ''))
                if running_loss / batch_print < best_batch[tracker]:
                    best_batch[tracker] = running_loss / batch_print
                if (fake_running_loss / float(steps)) / batch_print < fake_best_batch[tracker]:
                    fake_best_batch[tracker] = (fake_running_loss / float(steps)) / batch_print

                # Dump the latest generated images for inspection (assumes a 'fakes/' directory exists).
                img_num = 1
                for image in fake_images.data.cpu().numpy():
                    image = image.transpose(1, 2, 0)
                    # earlier YCbCr / rescaling experiments, kept for reference:
                    # for n in range(len(image)):
                    #     for m in range(len(image[n])):
                    #         for c in range(len(image[n, m])):
                    #             image[n, m, c] = ((image[n, m, c] + 1) / 2) * 255
                    # image = image.astype('uint8')
                    # image = color.ycbcr2rgb(image)
                    io.imsave('fakes/fake_image' + str(img_num) + '.bmp', image)
                    # image = cv2.imread('fakes/fake_image'+str(img_num)+'.bmp')
                    # image = cv2.cvtColor(image, cv2.COLOR_YCrCb2BGR)
                    # cv2.imwrite('fakes/fake_image'+str(img_num)+'.bmp', image)
                    img_num += 1
                # print(target_labels.data)
                # print(real_outputs)
                # print(fake_outputs)
                running_loss = 0.0
                fake_running_loss = 0.0
                tracker += 1
        # print('GroundTruth: ', ' '.join('%5s' % labels[j] for j in range(4)))
        # print('Predicted: ', ' '.join('%5s' % outputs[j] for j in range(4)))
        tracker = 0
        if switch:
            torch.save(net, 'trained.pt')
            torch.save(G, 'gen.pt')
        else:
            torch.save(net_2, 'trained_2.pt')
        torch.save(net, 'BACKUP.pt')
        torch.save(G, 'BACKUP_gen.pt')
        print('Saved model!')