Skip to content

Instantly share code, notes, and snippets.

@jinglescode
Created January 23, 2021 13:26
Show Gist options
  • Save jinglescode/a472e6c2fbec901446c3bcd98180f65b to your computer and use it in GitHub Desktop.
Save jinglescode/a472e6c2fbec901446c3bcd98180f65b to your computer and use it in GitHub Desktop.
# GAN training loop: for every batch, take one discriminator step followed by
# one generator step, and accumulate running mean losses for reporting.
# Relies on names set up before this snippet: n_epochs, dataloader, device,
# generator_net, discriminator_net, criterion, dim_noise, display_step, the
# two optimizers, and the two mean_*_loss accumulators.
for epoch in range(n_epochs):
    # The second element of each batch is discarded; only the samples are used.
    for real_samples, _ in tqdm(dataloader):
        batch_size = len(real_samples)
        # Flatten every sample to a single vector and move the batch to `device`.
        real_samples = real_samples.view(batch_size, -1).to(device)

        # train discriminator
        discriminator_optim.zero_grad()
        discriminator_loss = get_discriminator_loss(
            generator_net,
            discriminator_net,
            criterion,
            real_samples,
            batch_size,
            dim_noise,
            device,
        )
        # NOTE(review): retain_graph=True looks unnecessary, since the generator
        # loss below is built by a separate call that receives no tensors from
        # this step. Confirm against get_discriminator_loss before removing it.
        discriminator_loss.backward(retain_graph=True)
        discriminator_optim.step()
        # Each batch adds loss / display_step, so the accumulator equals the
        # mean loss once display_step batches have been summed.
        # Presumably it is reported and reset elsewhere -- not visible here.
        mean_discriminator_loss += discriminator_loss.item() / display_step

        # train generator
        # Clearing here also drops any gradients the discriminator backward
        # pass may have left on the generator's parameters.
        generator_optim.zero_grad()
        generator_loss = get_generator_loss(
            generator_net,
            discriminator_net,
            criterion,
            batch_size,
            dim_noise,
            device,
        )
        generator_loss.backward()
        generator_optim.step()
        mean_generator_loss += generator_loss.item() / display_step
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment