Gist by @jinglescode, created January 23, 2021 13:26
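```python
import torch

# Note: `get_noise` is not defined in this snippet. It is assumed to be a helper
# defined elsewhere that returns a batch of random noise vectors, presumably of
# shape (n_samples, dim_noise), on `device`.


def get_generator_loss(generator, discriminator, criterion, n_samples, dim_noise, device):
    '''
    Generate fake samples with the generator, then compute the generator's loss
    from the discriminator's predictions on those samples.

    Parameters:
        generator:
            generator network
        discriminator:
            discriminator network
        criterion:
            loss function, likely `nn.BCEWithLogitsLoss()`
        n_samples: int
            number of fake samples to generate
        dim_noise: int
            dimension of each noise vector
        device: string
            device to run on, 'cpu' or 'cuda'
    Returns:
        generator_loss:
            scalar loss tensor
    '''
    # Sample random noise and map it to fake samples
    random_noise = get_noise(n_samples, dim_noise, device=device)
    generated_samples = generator(random_noise)
    # Score the fake samples with the discriminator
    discriminator_fake_pred = discriminator(generated_samples)
    # The generator wants its fakes classified as real, so every target is 1
    generator_loss = criterion(discriminator_fake_pred, torch.ones_like(discriminator_fake_pred))
    return generator_loss
```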
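A minimal, self-contained sketch of how the function might be called in one generator update. The `get_noise` helper (standard normal noise), the two tiny `nn.Sequential` networks, the layer sizes, and the Adam learning rate are all hypothetical stand-ins, not the author's models; the function body is repeated so the block runs on its own. The block also checks that, with `nn.BCEWithLogitsLoss()` and all-ones targets, the loss on a logit `x` equals `-log(sigmoid(x)) = softplus(-x)`, averaged over the batch.

```python
import torch
from torch import nn
import torch.nn.functional as F


def get_noise(n_samples, dim_noise, device="cpu"):
    # Hypothetical helper: standard normal noise of shape (n_samples, dim_noise)
    return torch.randn(n_samples, dim_noise, device=device)


def get_generator_loss(generator, discriminator, criterion, n_samples, dim_noise, device):
    # Same logic as the gist function above
    random_noise = get_noise(n_samples, dim_noise, device=device)
    generated_samples = generator(random_noise)
    discriminator_fake_pred = discriminator(generated_samples)
    return criterion(discriminator_fake_pred, torch.ones_like(discriminator_fake_pred))


torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
dim_noise, dim_data, n_samples = 10, 4, 8

# Hypothetical toy networks; the discriminator outputs one raw logit per sample,
# which is what BCEWithLogitsLoss expects (no sigmoid at the end)
generator = nn.Sequential(nn.Linear(dim_noise, 16), nn.ReLU(), nn.Linear(16, dim_data)).to(device)
discriminator = nn.Sequential(nn.Linear(dim_data, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1)).to(device)
criterion = nn.BCEWithLogitsLoss()
gen_opt = torch.optim.Adam(generator.parameters(), lr=1e-3)

# One generator step: only the generator's optimizer is stepped. backward() also
# fills the discriminator's .grad, so zero those before any discriminator update.
gen_opt.zero_grad()
gen_loss = get_generator_loss(generator, discriminator, criterion, n_samples, dim_noise, device)
gen_loss.backward()
gen_opt.step()
print(gen_loss.item())  # a single non-negative float

# Cross-check: BCE with target 1 on logit x is -log(sigmoid(x)) = softplus(-x)
check_logits = torch.randn(n_samples, 1)
bce_value = criterion(check_logits, torch.ones_like(check_logits))
softplus_value = F.softplus(-check_logits).mean()
print(torch.allclose(bce_value, softplus_value))
```

Labeling the fakes as real (targets of 1) means the generator minimizes `-log D(G(z))`, the commonly used "non-saturating" generator objective, instead of directly minimizing `log(1 - D(G(z)))`.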