@jinglescode
Created January 23, 2021 13:25
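
import torch

# NOTE: `get_noise(n_samples, dim_noise, device=...)` is not defined in this snippet.
# It is assumed to be provided elsewhere and to return a batch of random noise
# vectors of dimension `dim_noise` on `device`.


def get_discriminator_loss(generator, discriminator, criterion, real_samples, n_samples, dim_noise, device):
    '''
    Run the discriminator on real and generated samples and return its loss.

    Parameters:
        generator:
            generator network
        discriminator:
            discriminator network
        criterion:
            loss function applied to raw logits, likely `nn.BCEWithLogitsLoss()`
        real_samples:
            batch of samples from the training dataset
        n_samples: int
            number of fake samples to generate
        dim_noise: int
            dimension of the noise vector
        device: string
            device to run on, 'cpu' or 'cuda'
    Returns:
        discriminator_loss:
            scalar loss, the mean of the fake-sample and real-sample losses
    '''
    # Generate a batch of fake samples from random noise
    random_noise = get_noise(n_samples, dim_noise, device=device)
    generated_samples = generator(random_noise)

    # Fake samples should be classified as 0; detach() keeps gradients
    # from flowing back into the generator during the discriminator update
    discriminator_fake_pred = discriminator(generated_samples.detach())
    discriminator_fake_loss = criterion(discriminator_fake_pred, torch.zeros_like(discriminator_fake_pred))

    # Real samples should be classified as 1
    discriminator_real_pred = discriminator(real_samples)
    discriminator_real_loss = criterion(discriminator_real_pred, torch.ones_like(discriminator_real_pred))

    # Average the two halves so fake and real batches contribute equally
    discriminator_loss = (discriminator_fake_loss + discriminator_real_loss) / 2
    return discriminator_loss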
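
The `criterion` above is expected to be `nn.BCEWithLogitsLoss()`, which works on raw logits rather than probabilities. To make the arithmetic concrete without PyTorch, here is a minimal pure-Python sketch, assuming the default `reduction='mean'` and no `weight` or `pos_weight`, of what that criterion computes and how the fake (label 0) and real (label 1) halves are averaged. The helper names `bce_with_logits` and `discriminator_loss_from_logits` are hypothetical and exist only for illustration; they take plain lists of logits in place of discriminator outputs.

```python
import math


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy on raw logits (sketch of the default
    reduction='mean' behaviour of nn.BCEWithLogitsLoss, no weights)."""
    total = 0.0
    for x, y in zip(logits, targets):
        # Numerically stable form of -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))]:
        # max(x, 0) - x*y + log(1 + exp(-|x|))
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


def discriminator_loss_from_logits(fake_logits, real_logits):
    """Hypothetical helper mirroring get_discriminator_loss on precomputed logits."""
    # Fake samples are scored against a target of 0
    fake_loss = bce_with_logits(fake_logits, [0.0] * len(fake_logits))
    # Real samples are scored against a target of 1
    real_loss = bce_with_logits(real_logits, [1.0] * len(real_logits))
    # Same equal-weight average as in get_discriminator_loss
    return (fake_loss + real_loss) / 2


if __name__ == "__main__":
    # A discriminator that outputs logit 0 (probability 0.5) everywhere
    print(discriminator_loss_from_logits([0.0, 0.0], [0.0, 0.0]))
```

With every logit at 0 the loss is ln 2 ≈ 0.693. A discriminator that confidently separates the two batches (for example fake logits of -10 and real logits of 10) drives the loss toward 0, while one that confidently gets both wrong pushes it above 10.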