Vanilla GAN
import torch
from torch import nn, optim


class VanillaGAN():
    def __init__(self, generator, discriminator, noise_fn, data_fn,
                 batch_size=32, device='cpu', lr_d=1e-3, lr_g=2e-4):
        """A GAN class for holding and training a generator and discriminator.

        Args:
            generator: a Generator network
            discriminator: a Discriminator network
            noise_fn: function f(num: int) -> pytorch tensor, (latent vectors)
            data_fn: function f(num: int) -> pytorch tensor, (real samples)
            batch_size: training batch size
            device: cpu or CUDA
            lr_d: learning rate for the discriminator
            lr_g: learning rate for the generator
        """
        self.generator = generator.to(device)
        self.discriminator = discriminator.to(device)
        self.noise_fn = noise_fn
        self.data_fn = data_fn
        self.batch_size = batch_size
        self.device = device
        self.criterion = nn.BCELoss()
        self.optim_d = optim.Adam(discriminator.parameters(),
                                  lr=lr_d, betas=(0.5, 0.999))
        self.optim_g = optim.Adam(generator.parameters(),
                                  lr=lr_g, betas=(0.5, 0.999))
        # Fixed label tensors: 1 = real, 0 = fake
        self.target_ones = torch.ones((batch_size, 1)).to(device)
        self.target_zeros = torch.zeros((batch_size, 1)).to(device)
    def generate_samples(self, latent_vec=None, num=None):
        """Sample from the generator.

        Args:
            latent_vec: a pytorch latent vector, or None
            num: the number of samples to generate if latent_vec is None

        If both latent_vec and num are None, use self.batch_size random
        latent vectors.
        """
        num = self.batch_size if num is None else num
        latent_vec = self.noise_fn(num) if latent_vec is None else latent_vec
        with torch.no_grad():  # no gradients needed for inference
            samples = self.generator(latent_vec)
        return samples
    def train_step_generator(self):
        """Train the generator one step and return the loss."""
        self.generator.zero_grad()
        latent_vec = self.noise_fn(self.batch_size)
        generated = self.generator(latent_vec)
        classifications = self.discriminator(generated)
        # Non-saturating loss: push the discriminator's output on
        # generated samples towards the "real" label
        loss = self.criterion(classifications, self.target_ones)
        loss.backward()
        self.optim_g.step()
        return loss.item()
    def train_step_discriminator(self):
        """Train the discriminator one step and return the losses."""
        self.discriminator.zero_grad()

        # real samples
        real_samples = self.data_fn(self.batch_size)
        pred_real = self.discriminator(real_samples)
        loss_real = self.criterion(pred_real, self.target_ones)

        # generated samples; no_grad because the generator is held
        # fixed here, so its gradients need not be tracked
        latent_vec = self.noise_fn(self.batch_size)
        with torch.no_grad():
            fake_samples = self.generator(latent_vec)
        pred_fake = self.discriminator(fake_samples)
        loss_fake = self.criterion(pred_fake, self.target_zeros)

        # combine the real and fake losses, averaged
        loss = (loss_real + loss_fake) / 2
        loss.backward()
        self.optim_d.step()
        return loss_real.item(), loss_fake.item()
    def train_step(self):
        """Train both networks and return the losses."""
        loss_d = self.train_step_discriminator()
        loss_g = self.train_step_generator()
        return loss_g, loss_d
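
# A minimal usage sketch. The Generator/Discriminator architectures, the
# latent size, and the toy target distribution (a 1-D Gaussian) below are
# illustrative assumptions, not prescribed by the class above.

class Generator(nn.Module):
    """Toy generator: maps a latent vector to a single scalar sample."""
    def __init__(self, latent_dim=16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 32),
            nn.LeakyReLU(0.2),
            nn.Linear(32, 1),
        )

    def forward(self, z):
        return self.net(z)


class Discriminator(nn.Module):
    """Toy discriminator: maps a sample to a real/fake probability."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, 32),
            nn.LeakyReLU(0.2),
            nn.Linear(32, 1),
            nn.Sigmoid(),  # BCELoss expects probabilities in (0, 1)
        )

    def forward(self, x):
        return self.net(x)


if __name__ == '__main__':
    latent_dim = 16
    device = 'cpu'
    # Latent vectors: standard normal noise of shape (num, latent_dim)
    noise_fn = lambda num: torch.randn((num, latent_dim), device=device)
    # "Real" samples: a shifted, scaled Gaussian, chosen arbitrarily
    data_fn = lambda num: 4.0 + 1.5 * torch.randn((num, 1), device=device)

    gan = VanillaGAN(Generator(latent_dim), Discriminator(),
                     noise_fn, data_fn, batch_size=32, device=device)
    for _ in range(1000):
        loss_g, (loss_d_real, loss_d_fake) = gan.train_step()
    print(gan.generate_samples(num=5))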