Skip to content

Instantly share code, notes, and snippets.

@ttchengab
Created June 21, 2021 08:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ttchengab/0a8b5820043c6352f5cbcb7764f2eb62 to your computer and use it in GitHub Desktop.
Save ttchengab/0a8b5820043c6352f5cbcb7764f2eb62 to your computer and use it in GitHub Desktop.
"""
Network training procedure
At every step, both the discriminator and the generator are updated using their respective losses
Discriminator aims to classify reals and fakes
Generator aims to generate images as realistic as possible
"""
def _sample_noise(batch_size, noise_dim=128):
    """Return a (batch_size, noise_dim) noise tensor, moved to `device`."""
    # torch.rand samples uniformly from [0, 1); (x - 0.5) / 0.5 rescales to [-1, 1).
    return ((torch.rand(batch_size, noise_dim) - 0.5) / 0.5).to(device)


# Alternating GAN training: for every batch, first update the discriminator D,
# then update the generator G.
for epoch in range(epochs):
    # Count batches from 1 so `idx` can be used directly for progress logging.
    for idx, (imgs, _) in enumerate(train_loader, start=1):
        # ---- Training the discriminator ----
        # Real inputs are actual images of the MNIST dataset
        # Fake inputs are from the generator
        # Real inputs should be classified as 1 and fake as 0
        real_inputs = imgs.to(device)
        batch_size = real_inputs.shape[0]
        real_outputs = D(real_inputs)
        real_label = torch.ones(batch_size, 1).to(device)

        # Detach the generated batch: the discriminator update must not
        # backpropagate into G. Any gradients written to G here would be
        # discarded by G_optimizer.zero_grad() below anyway, so computing
        # them only wastes time and memory.
        fake_inputs = G(_sample_noise(batch_size)).detach()
        fake_outputs = D(fake_inputs)
        fake_label = torch.zeros(fake_inputs.shape[0], 1).to(device)

        # Score real and fake predictions against their labels in one loss call.
        outputs = torch.cat((real_outputs, fake_outputs), 0)
        targets = torch.cat((real_label, fake_label), 0)
        D_loss = loss(outputs, targets)

        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        # ---- Training the generator ----
        # For generator, goal is to make the discriminator believe everything is 1
        # A fresh noise batch is drawn, and this time the graph through G is
        # kept so that G_loss can backpropagate into the generator.
        fake_inputs = G(_sample_noise(batch_size))
        fake_outputs = D(fake_inputs)
        fake_targets = torch.ones([fake_inputs.shape[0], 1]).to(device)
        G_loss = loss(fake_outputs, fake_targets)

        # G_loss.backward() also accumulates gradients in D's parameters;
        # they are cleared by D_optimizer.zero_grad() on the next iteration.
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        # Report progress every 100 batches and on the last batch of the epoch.
        if idx % 100 == 0 or idx == len(train_loader):
            print('Epoch {} Iteration {}: discriminator_loss {:.3f} generator_loss {:.3f}'.format(epoch, idx, D_loss.item(), G_loss.item()))

    # Checkpoint the generator every 10 epochs.
    # NOTE(review): the filename uses the 0-based epoch index (9, 19, ...),
    # not the 1-based count used in the condition; kept as-is so existing
    # checkpoint names do not change.
    # NOTE(review): torch.save(G, ...) pickles the whole module object, so
    # loading requires the same class definition to be importable; saving
    # G.state_dict() is the more portable format, but switching would break
    # any code that already loads these files.
    if (epoch + 1) % 10 == 0:
        torch.save(G, 'Generator_epoch_{}.pth'.format(epoch))
        print('Model saved.')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment