-
-
Save ttchengab/0a8b5820043c6352f5cbcb7764f2eb62 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Network training procedure.

Every step, both the discriminator and the generator are updated:
- The discriminator learns to classify real images as 1 and generated
  (fake) images as 0.
- The generator learns to produce images the discriminator classifies as 1,
  i.e. images that are as realistic as possible.
"""


def _sample_noise(batch_size, latent_dim=128):
    """Return a (batch_size, latent_dim) noise tensor drawn from U[-1, 1), moved to ``device``."""
    # torch.rand samples U[0, 1); shifting by 0.5 and dividing by 0.5 maps it
    # to U[-1, 1). Sampling on the CPU and then moving keeps the same RNG
    # stream as the original inline code.
    return ((torch.rand(batch_size, latent_dim) - 0.5) / 0.5).to(device)


for epoch in range(epochs):
    num_batches = len(train_loader)
    # start=1 so idx counts completed batches (replaces a manual `idx += 1`).
    for idx, (imgs, _) in enumerate(train_loader, start=1):
        # ---- Train the discriminator ----
        # Real inputs are images from the MNIST dataset; fake inputs come
        # from the generator. Reals are labelled 1 and fakes 0.
        real_inputs = imgs.to(device)
        batch_size = real_inputs.shape[0]

        real_outputs = D(real_inputs)
        real_label = torch.ones(batch_size, 1, device=device)

        fake_inputs = G(_sample_noise(batch_size))
        # detach(): the discriminator step must not backpropagate into G.
        # Without it, D_loss.backward() computes (and later discards)
        # gradients for G's parameters, wasting compute and memory.
        fake_outputs = D(fake_inputs.detach())
        fake_label = torch.zeros(batch_size, 1, device=device)

        outputs = torch.cat((real_outputs, fake_outputs), 0)
        targets = torch.cat((real_label, fake_label), 0)
        D_loss = loss(outputs, targets)
        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        # ---- Train the generator ----
        # Goal: make the discriminator classify freshly generated fakes as 1.
        fake_inputs = G(_sample_noise(batch_size))
        fake_outputs = D(fake_inputs)
        fake_targets = torch.ones(batch_size, 1, device=device)
        G_loss = loss(fake_outputs, fake_targets)
        # G_loss.backward() also writes gradients into D's parameters; those
        # are cleared by D_optimizer.zero_grad() before the next D update.
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        if idx % 100 == 0 or idx == num_batches:
            print('Epoch {} Iteration {}: discriminator_loss {:.3f} generator_loss {:.3f}'.format(epoch, idx, D_loss.item(), G_loss.item()))

    # Checkpoint the generator every 10 epochs. The file name uses the
    # 0-based epoch index (the 10th epoch is saved as Generator_epoch_9.pth).
    # NOTE(review): this pickles the whole module; saving G.state_dict() is
    # more portable if whatever loads these files can be updated too.
    if (epoch + 1) % 10 == 0:
        torch.save(G, 'Generator_epoch_{}.pth'.format(epoch))
        print('Model saved.')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment