@harsh-99
Last active February 27, 2020 16:12
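import torch

# --- Train the discriminator ---
# Assumes generator, discriminator, criterion (e.g. nn.BCELoss), optimizer_d,
# real_images, real_labels (ones), fake_labels (zeros), batch_size,
# random_size and device are defined elsewhere in the training loop.
output_d_real = discriminator(real_images)
d_real_loss = criterion(output_d_real, real_labels)
z = torch.randn(batch_size, random_size).to(device)
fake_images = generator(z)
# detach() keeps d_loss.backward() from computing gradients for the generator
output_d_fake = discriminator(fake_images.detach())
d_fake_loss = criterion(output_d_fake, fake_labels)
d_loss = d_real_loss + d_fake_loss
optimizer_d.zero_grad()
d_loss.backward()
optimizer_d.step()  # updates only the discriminator's parameters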
# --- Train the generator ---
# The generator's input is a noise vector of size random_size,
# created on the same device as the models
z = torch.randn(batch_size, random_size).to(device)
output_image = generator(z)
output_discriminator = discriminator(output_image)
# The generator succeeds when the discriminator classifies its images as real,
# so its loss compares the discriminator's output against real_labels.
g_loss = criterion(output_discriminator, real_labels)
optimizer_g.zero_grad()
g_loss.backward()
optimizer_g.step()  # updates only the generator's parameters
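The following is a minimal sketch that runs both steps end to end. It assumes tiny fully connected networks, `nn.BCELoss` as the criterion, Adam optimizers, and random tensors standing in for real images. The sizes (`image_size`, hidden widths) and learning rates are illustrative assumptions, not values from the original gist, and `train_step` is a hypothetical helper name.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Illustrative sizes (assumptions, not from the original gist)
batch_size, random_size, image_size = 16, 8, 32

# Tiny MLPs standing in for real generator/discriminator architectures
generator = nn.Sequential(
    nn.Linear(random_size, 64), nn.ReLU(),
    nn.Linear(64, image_size), nn.Tanh(),
).to(device)
discriminator = nn.Sequential(
    nn.Linear(image_size, 64), nn.LeakyReLU(0.2),
    nn.Linear(64, 1), nn.Sigmoid(),  # probability that the input is real
).to(device)

criterion = nn.BCELoss()  # assumed criterion
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=2e-4)
optimizer_g = torch.optim.Adam(generator.parameters(), lr=2e-4)

real_labels = torch.ones(batch_size, 1, device=device)
fake_labels = torch.zeros(batch_size, 1, device=device)


def train_step(real_images):
    """Hypothetical helper: one discriminator update, then one generator update."""
    # Discriminator: push D(real) toward 1 and D(G(z)) toward 0
    d_real_loss = criterion(discriminator(real_images), real_labels)
    z = torch.randn(batch_size, random_size, device=device)
    fake_images = generator(z)
    d_fake_loss = criterion(discriminator(fake_images.detach()), fake_labels)
    d_loss = d_real_loss + d_fake_loss
    optimizer_d.zero_grad()
    d_loss.backward()
    optimizer_d.step()

    # Generator: push D(G(z)) toward 1 by scoring fakes against real_labels
    z = torch.randn(batch_size, random_size, device=device)
    g_loss = criterion(discriminator(generator(z)), real_labels)
    optimizer_g.zero_grad()
    g_loss.backward()
    optimizer_g.step()
    return d_loss.item(), g_loss.item()


# Random tensors in [-1, 1] stand in for a batch of real images
real_batch = torch.rand(batch_size, image_size, device=device) * 2 - 1
d_loss, g_loss = train_step(real_batch)
print(f"d_loss={d_loss:.3f}  g_loss={g_loss:.3f}")
```

In the discriminator step, `detach()` ensures the fake images act as constant inputs. In the generator step, no detach is used, because the gradient has to flow back through the discriminator into the generator. That pass also leaves gradients on the discriminator's parameters, but `optimizer_d.zero_grad()` clears them before the next discriminator update.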