Skip to content

Instantly share code, notes, and snippets.

View RileyLazarou's full-sized avatar

Riley Lazarou RileyLazarou

  • Flatland Data Solutions
  • Saskatoon, Canada
View GitHub Profile
@RileyLazarou
RileyLazarou / vanilla_gan_discriminator.py
Last active June 22, 2020 13:54
vanilla gan discriminator
class Discriminator(nn.Module):
def __init__(self, input_dim, layers):
"""A discriminator for discerning real from generated samples.
params:
input_dim (int): width of the input
layers (List[int]): A list of layer widths including output width
Output activation is Sigmoid.
"""
@RileyLazarou
RileyLazarou / vanilla_gan.py
Last active June 22, 2020 13:55
vanilla gan
class VanillaGAN():
def __init__(self, generator, discriminator, noise_fn, data_fn,
batch_size=32, device='cpu', lr_d=1e-3, lr_g=2e-4):
"""A GAN class for holding and training a generator and discriminator
Args:
generator: a Generator network
discriminator: A Discriminator network
noise_fn: function f(num: int) -> pytorch tensor, (latent vectors)
data_fn: function f(num: int) -> pytorch tensor, (real samples)
@RileyLazarou
RileyLazarou / vanilla_gan_main.py
Created June 21, 2020 22:01
vanilla gan main
def main():
from time import time
epochs = 600
batches = 10
generator = Generator(1)
discriminator = Discriminator(1, [64, 32, 1])
noise_fn = lambda x: torch.rand((x, 1), device='cpu')
data_fn = lambda x: torch.randn((x, 1), device='cpu')
gan = VanillaGAN(generator, discriminator, noise_fn, data_fn, device='cpu')
loss_g, loss_d_real, loss_d_fake = [], [], []