Last active
June 22, 2020 13:55
-
-
Save RileyLazarou/b11b40fc41215e30b2c0f1bbc8f76847 to your computer and use it in GitHub Desktop.
vanilla gan
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class VanillaGAN():
    """Hold a generator/discriminator pair and train them adversarially.

    Both networks are optimised with Adam (betas=(0.5, 0.999)) against
    ``nn.BCELoss``. BCELoss expects probabilities, so the discriminator
    should end in something like a sigmoid; that is left to the caller.
    """

    def __init__(self, generator, discriminator, noise_fn, data_fn,
                 batch_size=32, device='cpu', lr_d=1e-3, lr_g=2e-4):
        """A GAN class for holding and training a generator and discriminator.

        Args:
            generator: a Generator network (torch module mapping latent
                vectors to samples)
            discriminator: a Discriminator network (torch module mapping
                samples to real/fake probabilities)
            noise_fn: function f(num: int) -> pytorch tensor (latent vectors)
            data_fn: function f(num: int) -> pytorch tensor (real samples)
            batch_size: training batch size
            device: cpu or CUDA
            lr_d: learning rate for the discriminator
            lr_g: learning rate for the generator
        """
        # Module.to() moves parameters in place and returns the same module.
        self.generator = generator.to(device)
        self.discriminator = discriminator.to(device)
        self.noise_fn = noise_fn
        self.data_fn = data_fn
        self.batch_size = batch_size
        self.device = device
        self.criterion = nn.BCELoss()
        self.optim_d = optim.Adam(self.discriminator.parameters(),
                                  lr=lr_d, betas=(0.5, 0.999))
        self.optim_g = optim.Adam(self.generator.parameters(),
                                  lr=lr_g, betas=(0.5, 0.999))
        # Kept for backward compatibility with code that reads them. The
        # training steps build targets shaped like each prediction instead,
        # so partial batches and differently shaped outputs also work.
        self.target_ones = torch.ones((batch_size, 1)).to(device)
        self.target_zeros = torch.zeros((batch_size, 1)).to(device)

    def generate_samples(self, latent_vec=None, num=None):
        """Sample from the generator without tracking gradients.

        Args:
            latent_vec: A pytorch latent vector or None
            num: The number of samples to generate if latent_vec is None

        If latent_vec and num are both None, use self.batch_size random
        latent vectors drawn from noise_fn. If latent_vec is given, num is
        ignored.

        Returns:
            The generator's output tensor (computed under torch.no_grad()).
        """
        num = self.batch_size if num is None else num
        latent_vec = self.noise_fn(num) if latent_vec is None else latent_vec
        with torch.no_grad():
            samples = self.generator(latent_vec)
        return samples

    def train_step_generator(self):
        """Train the generator one step and return its loss as a float."""
        self.generator.zero_grad()
        latent_vec = self.noise_fn(self.batch_size)
        generated = self.generator(latent_vec)
        classifications = self.discriminator(generated)
        # Non-saturating generator loss: label the fakes as real so the
        # generator is pushed towards outputs the discriminator accepts.
        loss = self.criterion(classifications,
                              torch.ones_like(classifications))
        loss.backward()
        # Only the generator is stepped. Gradients that also land on the
        # discriminator's parameters here are cleared by the
        # discriminator.zero_grad() call in train_step_discriminator.
        self.optim_g.step()
        return loss.item()

    def train_step_discriminator(self):
        """Train the discriminator one step and return the losses.

        Returns:
            (loss_real, loss_fake) as Python floats: the BCE loss on real
            samples (target 1) and on generated samples (target 0). The loss
            that is back-propagated is their mean.
        """
        self.discriminator.zero_grad()
        # real samples
        real_samples = self.data_fn(self.batch_size)
        pred_real = self.discriminator(real_samples)
        loss_real = self.criterion(pred_real, torch.ones_like(pred_real))
        # generated samples: no_grad so no graph is built through the
        # generator, which is not being updated in this step
        latent_vec = self.noise_fn(self.batch_size)
        with torch.no_grad():
            fake_samples = self.generator(latent_vec)
        pred_fake = self.discriminator(fake_samples)
        loss_fake = self.criterion(pred_fake, torch.zeros_like(pred_fake))
        # combine
        loss = (loss_real + loss_fake) / 2
        loss.backward()
        self.optim_d.step()
        return loss_real.item(), loss_fake.item()

    def train_step(self):
        """Train both networks and return the losses.

        Returns:
            (loss_g, (loss_real, loss_fake)). The discriminator is updated
            first, then the generator.
        """
        loss_d = self.train_step_discriminator()
        loss_g = self.train_step_generator()
        return loss_g, loss_d
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment