Skip to content

Instantly share code, notes, and snippets.

@bakirillov
Created August 10, 2019 17:41
Show Gist options
  • Save bakirillov/9753ad8605d577bc335ce8563a4c2d3f to your computer and use it in GitHub Desktop.
Save bakirillov/9753ad8605d577bc335ce8563a4c2d3f to your computer and use it in GitHub Desktop.
Simple GAN
#!/usr/bin/env python
# coding: utf-8
# In[1]:
import torch
import numpy as np
import torch.nn as nn
from tqdm import tqdm
import torchvision as tv
from torch.nn import BCELoss
import matplotlib.pyplot as plt
from IPython.display import Image
from torch.optim import Adam, SGD
from torchvision import transforms
from dataset import JojoDataset, ADEDataset, ToTensor
# In[37]:
# "Real" training set: 100500 rows of 5 independent standard-normal features
# (mean 0, standard deviation 1).
data = np.random.normal(0, 1, (100500, 5))
# In[38]:
# Notebook cell output; a no-op when the file is run as a plain script.
data.shape
# In[4]:
class Generator(nn.Module):
    """Fully connected generator mapping noise vectors to synthetic samples.

    The network is a three-layer MLP, ``latent_dim -> 32 -> 16 -> out_dim``,
    with ReLU between the linear layers and no activation on the output.

    Args:
        latent_dim: Width of each input noise vector. Defaults to 5, the
            value that was previously hard-coded.
        out_dim: Width of each generated sample. Defaults to 5, the value
            that was previously hard-coded.
    """

    def __init__(self, latent_dim=5, out_dim=5):
        super().__init__()
        # The attribute keeps its original name ``G`` so parameter names in
        # the state dict ("G.0.weight", ...) stay the same.
        self.G = nn.Sequential(
            nn.Linear(latent_dim, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 16),
            nn.ReLU(inplace=True),
            nn.Linear(16, out_dim),
        )

    def forward(self, x):
        """Return generated samples for the batch of noise vectors ``x``."""
        return self.G(x)
# In[5]:
class Discriminator(nn.Module):
    """Fully connected binary classifier producing one sigmoid score per sample.

    The network is a three-layer MLP, ``in_dim -> 32 -> 16 -> 1``, with ReLU
    between the linear layers and a sigmoid on the output, so every score
    lies in the range 0..1.

    Args:
        in_dim: Width of each input sample. Defaults to 5, the value that
            was previously hard-coded.
    """

    def __init__(self, in_dim=5):
        super().__init__()
        # The attribute keeps its original name ``D`` so parameter names in
        # the state dict ("D.0.weight", ...) stay the same.
        self.D = nn.Sequential(
            nn.Linear(in_dim, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 16),
            nn.ReLU(inplace=True),
            nn.Linear(16, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return scores of shape ``(batch, 1)`` for the batch of samples ``x``."""
        return self.D(x)
# In[39]:
# Build the generator first and the discriminator second, so the order in
# which their weights are randomly initialised is unchanged.
G = Generator()
D = Discriminator()
# In[40]:
# The generator is optimised with Adam at its default settings; the
# discriminator uses plain SGD with a learning rate of 0.002.
G_optimizer = Adam(G.parameters())
D_optimizer = SGD(D.parameters(), lr=0.002)
# In[8]:
def iterate_minibatches(X, y, batchsize):
    """Yield shuffled ``(X_batch, y_batch)`` pairs that cover the data once.

    A fresh random permutation of the row indices is drawn on every call and
    cut into consecutive slices of ``batchsize`` indices; the final batch is
    shorter when ``len(X)`` is not a multiple of ``batchsize``.

    Args:
        X: Array of samples; must support ``len`` and indexing with an
            integer index array.
        y: Array of targets aligned with ``X`` row by row.
        batchsize: Maximum number of rows per batch; must be positive.

    Yields:
        Tuples ``(X[ix], y[ix])`` that share the same shuffled indices.

    Raises:
        ValueError: If ``X`` and ``y`` differ in length or ``batchsize`` is
            not positive. Because this is a generator, the check runs when
            the first batch is requested.
    """
    n_samples = len(X)
    if len(y) != n_samples:
        raise ValueError(
            "X and y must have the same length, got %d and %d"
            % (n_samples, len(y))
        )
    # A negative step would make range() empty and silently yield nothing.
    if batchsize <= 0:
        raise ValueError("batchsize must be positive, got %r" % (batchsize,))
    indices = np.random.permutation(n_samples)
    for start in range(0, n_samples, batchsize):
        ix = indices[start:start + batchsize]
        yield X[ix], y[ix]
# In[9]:
# Number of rows per minibatch, used for both real and generated batches.
BATCH_SIZE = 256
# In[41]:
# Two separate binary cross-entropy criteria: one for the discriminator
# updates and one for the generator updates.
loss_D, loss_G = BCELoss(), BCELoss()
# In[42]:
def _sample_noise(n_rows):
    """Return an ``(n_rows, 5)`` float tensor of uniform(0, 1) generator inputs."""
    return torch.from_numpy(
        np.random.uniform(0, 1, size=(n_rows, 5))
    ).type(torch.FloatTensor)


def epoch():
    """Run one GAN training epoch over the module-level ``data``.

    Phase 1 makes one discriminator update per minibatch of real data. Real
    rows are scored against a target of 1 and a fresh batch of generated
    rows against a target of 0; both gradients are accumulated before a
    single ``D_optimizer`` step.

    Phase 2 makes as many generator updates as there were minibatches in
    phase 1, each pushing the discriminator's score on generated rows
    towards 1.

    NOTE(review): the original indentation was lost, so the placement of
    the generator loop (after the minibatch loop rather than nested inside
    it) is a reconstruction -- confirm against the original notebook.

    Returns:
        A tuple ``(errors_D, errors_D_fake, errors_G)`` of lists holding one
        float loss value per update step.
    """
    errors_D = []
    errors_D_fake = []
    errors_G = []
    n_batches = 0
    # iterate_minibatches requires targets; they are unused here, so pass zeros.
    dummy_targets = np.zeros(data.shape[0], dtype=int)
    for batch_X, _ in iterate_minibatches(data, dummy_targets, BATCH_SIZE):
        n_batches += 1
        D.zero_grad()
        batch_Xt = torch.from_numpy(batch_X).type(torch.FloatTensor)
        batch_Y = D(batch_Xt)
        # D outputs shape (N, 1); BCELoss needs a target of the same shape,
        # so build it from the prediction instead of a flat (N,) vector.
        error_D = loss_D(batch_Y, torch.ones_like(batch_Y))
        error_D.backward()
        # Detach the generated rows: this step only updates D, so there is
        # no reason to backpropagate through G.
        fake_Xt = G(_sample_noise(BATCH_SIZE)).detach()
        fake_Y = D(fake_Xt)
        error_D_fake = loss_D(fake_Y, torch.zeros_like(fake_Y))
        error_D_fake.backward()
        D_optimizer.step()
        errors_D.append(error_D.item())
        errors_D_fake.append(error_D_fake.item())
    for _ in range(n_batches):
        G.zero_grad()
        fake_Y = D(G(_sample_noise(BATCH_SIZE)))
        # The generator is rewarded when D labels its samples as real (1).
        # Gradients also accumulate in D here, but D.zero_grad() clears them
        # before the next discriminator update.
        error_G = loss_G(fake_Y, torch.ones_like(fake_Y))
        error_G.backward()
        G_optimizer.step()
        errors_G.append(error_G.item())
    return errors_D, errors_D_fake, errors_G
# In[43]:
# One entry per epoch: the (errors_D, errors_D_fake, errors_G) tuple that
# epoch() returns.
history = []
# In[44]:
# Train for 100 epochs with a tqdm progress bar.
for _ in tqdm(range(100)):
    history.append(epoch())
# In[45]:
# Each history entry is one epoch's (errors_D, errors_D_fake, errors_G)
# tuple. A loss curve is therefore one tuple position concatenated across
# all epochs -- indexing history[k] directly would instead mix the three
# losses of epoch k into a single line.
for position, label in enumerate(("D (real)", "D (fake)", "G")):
    plt.plot(
        [err for result in history for err in result[position]],
        label=label,
    )
plt.xlabel("#batch")
plt.ylabel("Loss")
plt.legend()
plt.show()
# In[46]:
def _generate(n_rows=256):
    """Return ``n_rows`` generator outputs as a NumPy array, fed uniform(0, 1) noise."""
    noise = torch.from_numpy(np.random.uniform(0, 1, (n_rows, 5)))
    return G(noise.type(torch.FloatTensor)).data.numpy()


def _hist_fake_vs_real(fake, n_rows=256):
    """Overlay histograms of ``fake`` and of ``n_rows`` randomly chosen rows of ``data``."""
    plt.hist(fake)
    real_rows = np.random.choice(np.arange(data.shape[0]), n_rows)
    plt.hist(data[real_rows])
    plt.show()


_hist_fake_vs_real(_generate())
# In[47]:
# Same comparison repeated with a fresh draw of noise and real rows.
_hist_fake_vs_real(_generate())
# In[48]:
# Keep one batch of generated samples for the statistics below.
u = _generate()
# In[49]:
_hist_fake_vs_real(u)
# In[50]:
# Notebook cell outputs: per-feature mean and standard deviation of the
# generated batch. These are no-ops when the file is run as a plain script.
np.mean(u, 0)
# In[51]:
np.std(u, 0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment