Skip to content

Instantly share code, notes, and snippets.

@chmodsss
Created January 3, 2020 22:37
Show Gist options
  • Save chmodsss/018529e2ec68f8971a33235570f6f3d8 to your computer and use it in GitHub Desktop.
Save chmodsss/018529e2ec68f8971a33235570f6f3d8 to your computer and use it in GitHub Desktop.
Generative adversarial network (GAN) using PyTorch
import torch
from torchvision.datasets import MNIST
from torch import nn
from torchvision import transforms
import torch.optim as optim
from matplotlib import pyplot as plt
import matplotlib.cm as cm
import warnings
warnings.filterwarnings("ignore")
# Pick the first GPU when CUDA is available and fall back to the CPU
# otherwise, so `device` is always a target that actually exists.
is_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if is_cuda else "cpu")

# Convert each image to a tensor, normalise it with mean 0.5 / std 0.5 and
# flatten it to a 1-D vector (the networks below are fully connected).
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((.5,), (.5,)),
    torch.flatten,
])
traindata = MNIST(root='./data', transform=trans, train=True, download=True)

# NOTE: despite its name, `batches` is the number of samples per mini-batch
# (it is passed to the DataLoader as `batch_size`).
batches = 6000

# Optional smaller experiment: keep only the digit 0 and shrink the batch.
#traindata = [(data,label) for (data,label) in traindata if label==0]
#traindata = traindata[:5000]
#batches = 500
trainloader = torch.utils.data.DataLoader(traindata, batch_size=batches, shuffle=True)
class Discriminator(nn.Module):
    """Fully connected binary classifier that scores samples as real or fake.

    The network is two hidden ``Linear -> LeakyReLU -> Dropout`` stages
    followed by a ``Linear -> Sigmoid`` output, so every output value lies
    in (0, 1).

    Args:
        ip_emb: Width of the input vectors (784 by default).
        emb1: Width of the first hidden layer.
        emb2: Width of the second hidden layer.
        out_emb: Width of the output (1 score per sample by default).
        negative_slope: Slope used by both LeakyReLU activations.
        dropout: Drop probability used by both Dropout layers.
    """

    def __init__(self, ip_emb=784, emb1=256, emb2=128, out_emb=1,
                 negative_slope=0.2, dropout=0.3):
        super().__init__()
        # Attribute names are kept stable so previously saved models and
        # state dicts continue to load.
        self.layer1 = nn.Sequential(
            nn.Linear(ip_emb, emb1),
            nn.LeakyReLU(negative_slope),
            nn.Dropout(dropout))
        self.layer2 = nn.Sequential(
            nn.Linear(emb1, emb2),
            nn.LeakyReLU(negative_slope),
            nn.Dropout(dropout))
        self.layer_out = nn.Sequential(
            nn.Linear(emb2, out_emb),
            nn.Sigmoid())

    def forward(self, x):
        """Return the score for each row of ``x`` (shape ``(N, ip_emb)``)."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer_out(x)
        return x
class Generator(nn.Module):
    """Fully connected network that maps noise vectors to flat samples.

    The network is three hidden ``Linear -> LeakyReLU`` stages followed by a
    ``Linear -> Tanh`` output, so every output value lies in [-1, 1].

    Args:
        ip_emb: Width of the input noise vectors (128 by default).
        emb1: Width of the first hidden layer.
        emb2: Width of the second hidden layer.
        emb3: Width of the third hidden layer.
        out_emb: Width of the generated vectors (784 by default).
        negative_slope: Slope used by all LeakyReLU activations.
    """

    def __init__(self, ip_emb=128, emb1=256, emb2=512, emb3=1024,
                 out_emb=784, negative_slope=0.2):
        super().__init__()
        # Attribute names are kept stable so previously saved models and
        # state dicts continue to load.
        self.layer1 = nn.Sequential(
            nn.Linear(ip_emb, emb1),
            nn.LeakyReLU(negative_slope))
        self.layer2 = nn.Sequential(
            nn.Linear(emb1, emb2),
            nn.LeakyReLU(negative_slope))
        self.layer3 = nn.Sequential(
            nn.Linear(emb2, emb3),
            nn.LeakyReLU(negative_slope))
        self.layer_out = nn.Sequential(
            nn.Linear(emb3, out_emb),
            nn.Tanh())

    def forward(self, x):
        """Return a generated vector for each row of ``x`` (``(N, ip_emb)``)."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer_out(x)
        return x
# Build both networks; move them to the GPU only when one is available.
discriminator = Discriminator()
generator = Generator()
if is_cuda:
    discriminator.to(device)
    generator.to(device)

# Binary cross-entropy on the discriminator's sigmoid output, with one Adam
# optimizer per network (same learning rate for both).
criterion = nn.BCELoss()
discrim_optim = optim.Adam(discriminator.parameters(), lr=0.0002)
generat_optim = optim.Adam(generator.parameters(), lr=0.0002)
def noise(x, y):
    """Return an ``(x, y)`` tensor of standard-normal samples.

    The tensor is created on the CPU and copied to the GPU when the
    module-level ``is_cuda`` flag is set.
    """
    sample = torch.randn(x, y)
    return sample.cuda() if is_cuda else sample
def get_nearones(x):
    """Return an ``(x, 1)`` target tensor filled with ``1 - 0.01``.

    Used as the "real" label for the BCE loss; the tensor is moved to the
    GPU first when the module-level ``is_cuda`` flag is set.
    """
    targets = torch.ones(x, 1)
    if is_cuda:
        targets = targets.cuda()
    return targets - 0.01
def get_nearzeros(x):
    """Return an ``(x, 1)`` target tensor filled with ``0 + 0.01``.

    Used as the "fake" label for the BCE loss; the tensor is moved to the
    GPU first when the module-level ``is_cuda`` flag is set.
    """
    targets = torch.zeros(x, 1)
    if is_cuda:
        targets = targets.cuda()
    return targets + 0.01
def plotimage(is_cuda=None):
    """Generate one sample with the global ``generator`` and display it.

    A single 128-dimensional noise vector is fed to the generator and the
    result is shown as a 28x28 grayscale image.

    Args:
        is_cuda: Accepted for backward compatibility and ignored. The sample
            is always copied to host memory with ``.cpu()``, which is a
            no-op for tensors that already live on the CPU, so no separate
            code path is needed.
    """
    # Plotting never needs gradients; skip building the autograd graph.
    with torch.no_grad():
        sample = generator(noise(1, 128))
    plt.imshow(sample.cpu().view(28, 28).numpy(), cmap=cm.gray)
    plt.show()
# Per-epoch training history (one Python float per epoch in each list).
derrors = []  # mean discriminator loss
gerrors = []  # mean generator loss
dxcumul = []  # mean discriminator output on real samples, D(x)
gxcumul = []  # mean discriminator output on generated samples, D(G(z))

num_epochs = 2000
# Average over the real number of mini-batches instead of assuming 10.
num_batches = len(trainloader)

for epoch in range(num_epochs):
    dx = 0.0
    gx = 0.0
    derr = 0.0
    gerr = 0.0
    for pos_samples in trainloader:
        pos_sample = pos_samples[0].cuda() if is_cuda else pos_samples[0]
        # Size targets and noise from the actual batch so a short final
        # batch does not cause a shape mismatch in the loss.
        batch_len = pos_sample.size(0)

        # Training Discriminator network
        discrim_optim.zero_grad()
        pos_predicted = discriminator(pos_sample)
        pos_error = criterion(pos_predicted, get_nearones(batch_len))
        # Detach the fakes: this step must not backpropagate into the
        # generator (its gradients would be discarded anyway).
        neg_samples = generator(noise(batch_len, 128)).detach()
        neg_predicted = discriminator(neg_samples)
        neg_error = criterion(neg_predicted, get_nearzeros(batch_len))
        discriminator_error = pos_error + neg_error
        discriminator_error.backward()
        discrim_optim.step()

        # Training generator network
        generat_optim.zero_grad()
        gen_samples = generator(noise(batch_len, 128))
        gen_predicted = discriminator(gen_samples)
        # The generator is rewarded when its samples are scored as real.
        generator_error = criterion(gen_predicted, get_nearones(batch_len))
        generator_error.backward()
        generat_optim.step()

        # Accumulate plain floats so no autograd graph is kept alive
        # across iterations.
        derr += discriminator_error.item()
        gerr += generator_error.item()
        dx += pos_predicted.detach().mean().item()
        gx += neg_predicted.detach().mean().item()

    print(f'Epoch:{epoch}.. D x : {dx/num_batches:.4f}.. G x: {gx/num_batches:.4f}.. D err : {derr/num_batches:.4f}.. G err: {gerr/num_batches:.4f}')
    # Checkpoint both networks after every epoch (whole-module pickles).
    torch.save(discriminator, 'discriminator_model.pt')
    torch.save(generator, 'generator_model.pt')
    derrors.append(derr / num_batches)
    gerrors.append(gerr / num_batches)
    dxcumul.append(dx / num_batches)
    gxcumul.append(gx / num_batches)
    if epoch % 10 == 0:
        plotimage(is_cuda)

# Plotting the errors (discriminator in red, generator in green); the x-axis
# follows the number of epochs actually recorded.
plt.plot(range(len(derrors)), derrors, color='r')
plt.plot(range(len(gerrors)), gerrors, color='g')
plt.show()
# Show ten fresh samples drawn from the generator network.
for _ in range(10):
    plotimage(is_cuda)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment