CycleGAN toy example
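A minimal CycleGAN on toy data: two unpaired sets of 5-dimensional samples, one standard normal and one Poisson(1), are translated into each other by a pair of small MLP generators, each trained against its own MLP discriminator with BCE adversarial losses plus an MSE cycle-consistency term.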
#!/usr/bin/env python
# coding: utf-8
# In[1]:
import torch
import numpy as np
import torch.nn as nn
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.optim import Adam, SGD
from torch.nn import BCELoss, MSELoss
# In[2]:
# Two unpaired toy "domains": 5-dimensional standard normal samples
# and 5-dimensional Poisson(1) samples.
data_1 = np.random.normal(loc=0, scale=1, size=(100500, 5))
data_2 = np.random.poisson(size=(100500, 5))
# In[3]:
class Generator(nn.Module):
    """Small MLP that maps a 5-dimensional sample from one domain to the other."""

    def __init__(self):
        super(Generator, self).__init__()
        self.G = nn.Sequential(
            nn.Linear(5, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 16),
            nn.ReLU(inplace=True),
            nn.Linear(16, 5)
        )

    def forward(self, x):
        return self.G(x)
# In[4]:
class Discriminator(nn.Module):
    """Small MLP that scores a 5-dimensional sample as real (1) or generated (0)."""

    def __init__(self):
        super(Discriminator, self).__init__()
        self.D = nn.Sequential(
            nn.Linear(5, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 16),
            nn.ReLU(inplace=True),
            nn.Linear(16, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.D(x)
# In[5]:
# One generator/discriminator pair per translation direction.
G_norm2pois = Generator()
D_norm2pois = Discriminator()
G_pois2norm = Generator()
D_pois2norm = Discriminator()
# In[6]:
# Adam for the generators, plain SGD (lr=0.002) for the discriminators.
G_n2p_optimizer = Adam(G_norm2pois.parameters())
D_n2p_optimizer = SGD(D_norm2pois.parameters(), 0.002)
G_p2n_optimizer = Adam(G_pois2norm.parameters())
D_p2n_optimizer = SGD(D_pois2norm.parameters(), 0.002)
# In[7]:
def iterate_minibatches(X, y, batchsize):
    """Yield shuffled minibatches from two unpaired datasets of equal length."""
    indices = np.random.permutation(np.arange(len(X)))
    for start in range(0, len(indices), batchsize):
        ix = indices[start: start + batchsize]
        yield X[ix], y[ix]
# In[8]:
BATCH_SIZE = 256
# In[9]:
# Adversarial losses for each direction.
loss_D_n2p = BCELoss()
loss_G_n2p = BCELoss()
loss_D_p2n = BCELoss()
loss_G_p2n = BCELoss()
# In[10]:
# MSE cycle-consistency losses for the forward and backward cycles.
loss_ccl_forward = MSELoss()
loss_ccl_backward = MSELoss()
# In[11]:
def epoch():
    """One pass over the data; returns per-batch losses for both GAN pairs."""
    errors_D_n2p = []
    errors_D_n2p_fake = []
    errors_G_n2p = []
    errors_D_p2n = []
    errors_D_p2n_fake = []
    errors_G_p2n = []
    for b_norm, b_pois in iterate_minibatches(data_1, data_2, BATCH_SIZE):
        b_norm_t = torch.from_numpy(b_norm).type(torch.FloatTensor)
        b_pois_t = torch.from_numpy(b_pois).type(torch.FloatTensor)
        # Train D_norm2pois: real Poisson batches -> 1, generated batches -> 0.
        D_norm2pois.zero_grad()
        b_true_pois = D_norm2pois(b_pois_t)
        error_D_n2p = loss_D_n2p(b_true_pois, torch.ones_like(b_true_pois))
        error_D_n2p.backward()
        b_fake_pois = D_norm2pois(G_norm2pois(b_norm_t).detach())
        error_D_n2p_fake = loss_D_n2p(b_fake_pois, torch.zeros_like(b_fake_pois))
        error_D_n2p_fake.backward()
        D_n2p_optimizer.step()
        errors_D_n2p.append(error_D_n2p.item())
        errors_D_n2p_fake.append(error_D_n2p_fake.item())
        # Train G_norm2pois: fool D_norm2pois into labelling fakes as real.
        G_norm2pois.zero_grad()
        b_fake_pois = D_norm2pois(G_norm2pois(b_norm_t))
        error_G_n2p = loss_G_n2p(b_fake_pois, torch.ones_like(b_fake_pois))
        error_G_n2p.backward()
        G_n2p_optimizer.step()
        errors_G_n2p.append(error_G_n2p.item())
        # Train D_pois2norm: real normal batches -> 1, generated batches -> 0.
        D_pois2norm.zero_grad()
        b_true_norm = D_pois2norm(b_norm_t)
        error_D_p2n = loss_D_p2n(b_true_norm, torch.ones_like(b_true_norm))
        error_D_p2n.backward()
        b_fake_norm = D_pois2norm(G_pois2norm(b_pois_t).detach())
        error_D_p2n_fake = loss_D_p2n(b_fake_norm, torch.zeros_like(b_fake_norm))
        error_D_p2n_fake.backward()
        D_p2n_optimizer.step()
        errors_D_p2n.append(error_D_p2n.item())
        errors_D_p2n_fake.append(error_D_p2n_fake.item())
        # Train G_pois2norm, plus the cycle-consistency terms: translating to
        # the other domain and back should reproduce the original input.
        G_pois2norm.zero_grad()
        b_fake_norm = D_pois2norm(G_pois2norm(b_pois_t))
        error_G_p2n = loss_G_p2n(b_fake_norm, torch.ones_like(b_fake_norm))
        error_G_p2n.backward()
        ccl = loss_ccl_forward(G_pois2norm(G_norm2pois(b_norm_t)), b_norm_t)
        ccl += loss_ccl_backward(G_norm2pois(G_pois2norm(b_pois_t)), b_pois_t)
        ccl.backward()
        G_p2n_optimizer.step()
        errors_G_p2n.append(error_G_p2n.item())
    return (
        errors_D_p2n, errors_D_n2p,
        errors_D_p2n_fake, errors_D_n2p_fake,
        errors_G_p2n, errors_G_n2p
    )
# In[12]:
history = []
# In[13]:
for a in tqdm(np.arange(100)):
    history.append(epoch())
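# In[14]:
# A quick way to eyeball training (a sketch added here, not part of the
# original run): average each per-batch loss list over its epoch and plot the
# curves. The label order matches the tuple returned by epoch().
labels = ["D_p2n", "D_n2p", "D_p2n_fake", "D_n2p_fake", "G_p2n", "G_n2p"]
for i, label in enumerate(labels):
    plt.plot([np.mean(h[i]) for h in history], label=label)
plt.legend()
plt.xlabel("epoch")
plt.ylabel("mean loss")
plt.show()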
# In[24]:
history[0]
# In[25]:
# Compare generated "Poisson" samples against real Poisson samples.
plt.hist(
    G_norm2pois(torch.from_numpy(np.random.normal(0, 1, (256, 5))).type(torch.FloatTensor)).data.numpy()
)
plt.hist(data_2[np.random.choice(np.arange(data_2.shape[0]), 256)])
plt.show()
# In[26]:
# Real normal samples, for reference.
norm = np.random.normal(0, 1, (256, 5))
# In[28]:
plt.hist(norm)
plt.show()
# In[31]:
# Normal samples pushed through G_norm2pois should now look Poisson-like.
fake_pois = G_norm2pois(torch.from_numpy(norm).type(torch.FloatTensor)).data.numpy()
# In[32]:
plt.hist(fake_pois)
plt.show()
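# In[33]:
# Optional check (not in the original gist): np.random.poisson defaults to
# lam=1, and Poisson(1) has mean = variance = 1, so the per-dimension moments
# of the generated samples should be close to 1 if training worked.
print(fake_pois.mean(0), fake_pois.var(0))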
# In[35]:
# Real Poisson samples, for reference.
pois = np.random.poisson(size=(256, 5))
plt.hist(pois)
plt.show()
# In[36]:
# Poisson samples pushed through G_pois2norm should now look Gaussian.
fake_norm = G_pois2norm(torch.from_numpy(pois).type(torch.FloatTensor)).data.numpy()
plt.hist(fake_norm)
plt.show()
# In[ ]:
u = G_pois2norm(torch.from_numpy(pois).type(torch.FloatTensor)).data.numpy()
# In[ ]:
plt.hist(u)
plt.hist(data_1[np.random.choice(np.arange(data_1.shape[0]), 256)])
plt.show()
# In[ ]:
# Generated "normal" samples should have per-dimension mean near 0...
np.mean(u, 0)
# In[ ]:
# ...and standard deviation near 1.
np.std(u, 0)