Skip to content

Instantly share code, notes, and snippets.

@RileyLazarou
Created June 21, 2020 22:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save RileyLazarou/e4cc67855fc2013acb26db8c165c0514 to your computer and use it in GitHub Desktop.
Save RileyLazarou/e4cc67855fc2013acb26db8c165c0514 to your computer and use it in GitHub Desktop.
Vanilla GAN: main training script
def main(epochs=600, batches=10, device='cpu'):
    """Train a 1-D vanilla GAN and print the per-epoch average losses.

    Latent noise is drawn with ``torch.rand`` (uniform on [0, 1)) and the
    "real" samples with ``torch.randn`` (standard normal). Presumably the
    generator learns to map the former onto the latter; the actual training
    logic lives in ``VanillaGAN.train_step`` (defined elsewhere).

    Args:
        epochs: Number of epochs to run. Defaults to 600.
        batches: Number of ``gan.train_step()`` calls per epoch. The reported
            losses are averaged over these calls. Must be positive.
            Defaults to 10.
        device: Torch device used to sample noise and data, and passed to
            ``VanillaGAN``. Defaults to ``'cpu'``.

    Returns:
        A tuple ``(loss_g, loss_d_real, loss_d_fake)`` of lists, one entry per
        epoch. Each entry is the average of the corresponding value returned
        by ``gan.train_step()`` as ``lg, (ldr, ldf)``. Judging by the names,
        these are the generator loss and the discriminator losses on real and
        fake samples.

    Raises:
        ValueError: If ``batches`` is not positive (the per-epoch average
            would otherwise divide by zero).
    """
    from time import perf_counter

    # Validate before building any models so bad input fails fast.
    if batches <= 0:
        raise ValueError(f"batches must be positive, got {batches}")

    # The 1 here presumably matches the 1-D samples produced below -- the
    # Generator/Discriminator signatures are defined elsewhere; confirm.
    generator = Generator(1)
    discriminator = Discriminator(1, [64, 32, 1])

    def noise_fn(x):
        # x uniform samples in [0, 1), shaped (x, 1). The parameter name is
        # kept as `x` in case VanillaGAN passes it by keyword.
        return torch.rand((x, 1), device=device)

    def data_fn(x):
        # x standard-normal samples, shaped (x, 1).
        return torch.randn((x, 1), device=device)

    gan = VanillaGAN(generator, discriminator, noise_fn, data_fn, device=device)
    loss_g, loss_d_real, loss_d_fake = [], [], []
    start = perf_counter()  # monotonic clock: unaffected by system time changes
    for epoch in range(epochs):
        loss_g_running, loss_d_real_running, loss_d_fake_running = 0, 0, 0
        for _ in range(batches):
            lg_, (ldr_, ldf_) = gan.train_step()
            loss_g_running += lg_
            loss_d_real_running += ldr_
            loss_d_fake_running += ldf_
        loss_g.append(loss_g_running / batches)
        loss_d_real.append(loss_d_real_running / batches)
        loss_d_fake.append(loss_d_fake_running / batches)
        print(f"Epoch {epoch+1}/{epochs} ({int(perf_counter() - start)}s):"
              f" G={loss_g[-1]:.3f},"
              f" Dr={loss_d_real[-1]:.3f},"
              f" Df={loss_d_fake[-1]:.3f}")
    return loss_g, loss_d_real, loss_d_fake
if __name__ == "__main__":
    # Run the training demo only when this file is executed as a script,
    # not when it is imported as a module.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment