
@ethanyanjiali
Created June 6, 2019 06:03
cyclegan_training
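def train_step(images_a, images_b, epoch, step):
    # Update both generators and keep the fakes they produced.
    fake_a2b, fake_b2a, gen_loss_dict = train_generator(images_a, images_b)
    # Mix older fakes from the history pools into what the discriminators see.
    fake_b2a_from_pool = fake_pool_b2a.query(fake_b2a)
    fake_a2b_from_pool = fake_pool_a2b.query(fake_a2b)
    dis_loss_dict = train_discriminator(images_a, images_b, fake_a2b_from_pool, fake_b2a_from_pool)
    # Return the losses so the caller can log them.
    return gen_loss_dict, dis_loss_dict


def train(dataset, epochs):
    # Resume from the epoch after the last one recorded in the checkpoint.
    for epoch in range(checkpoint.epoch + 1, epochs + 1):
        for step, batch in enumerate(dataset):
            # Each batch is a pair: (images from domain A, images from domain B).
            train_step(batch[0], batch[1], epoch, step)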
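`train_step` passes the freshly generated fakes through `fake_pool_b2a.query` and `fake_pool_a2b.query` before the discriminators are trained. The pool objects are not defined in this gist. A minimal sketch, assuming they follow the image-history buffer described in the CycleGAN paper (keep up to 50 past fakes and, once the buffer is full, return a stored fake with probability 0.5 while swapping the new one in), could look like this. The `ImagePool` class name and its `seed` parameter are hypothetical, and the sketch works on a Python list of images rather than a tensor batch.

```python
import random


class ImagePool:
    """Hypothetical history buffer of previously generated images."""

    def __init__(self, pool_size=50, seed=None):
        self.pool_size = pool_size
        self.images = []
        self.rng = random.Random(seed)

    def query(self, images):
        # A pool size of 0 disables the buffer: fakes pass straight through.
        if self.pool_size == 0:
            return list(images)
        result = []
        for image in images:
            if len(self.images) < self.pool_size:
                # Buffer not full yet: remember this fake and return it as-is.
                self.images.append(image)
                result.append(image)
            elif self.rng.random() < 0.5:
                # Return an older fake and store the new one in its slot.
                idx = self.rng.randrange(self.pool_size)
                result.append(self.images[idx])
                self.images[idx] = image
            else:
                # Return the new fake without touching the buffer.
                result.append(image)
        return result


pool = ImagePool(pool_size=2, seed=0)
print(pool.query(["f1", "f2"]))  # buffer is still filling, so → ['f1', 'f2']
```

Feeding the discriminators a mix of current and older fakes is intended to reduce oscillation during training, since each discriminator update sees a history of generator outputs rather than only the latest batch. In a TensorFlow version, the returned list would typically be stacked back into a batch tensor.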
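`train_generator` is also not shown here; it returns the two fakes plus a `gen_loss_dict`. As an illustration only, assuming it uses the loss setup from the CycleGAN paper (a least-squares adversarial loss, an L1 cycle-consistency loss weighted by λ = 10, and an L1 identity loss weighted by 0.5λ), the entries of such a dictionary could be computed as below. The function name, argument names, and dictionary keys are hypothetical, and flat lists of floats stand in for image tensors.

```python
def mean_abs_error(x, y):
    # L1 distance averaged over elements.
    return sum(abs(a - b) for a, b in zip(x, y)) / len(x)


def mean_sq_error(x, y):
    # Squared error averaged over elements.
    return sum((a - b) ** 2 for a, b in zip(x, y)) / len(x)


def generator_losses(disc_fake_b, disc_fake_a,
                     real_a, cycled_a, real_b, cycled_b,
                     same_a, same_b,
                     lambda_cycle=10.0, lambda_identity=0.5):
    # Least-squares adversarial terms: each generator wants the
    # discriminator to output 1 on its fakes.
    adv_a2b = mean_sq_error(disc_fake_b, [1.0] * len(disc_fake_b))
    adv_b2a = mean_sq_error(disc_fake_a, [1.0] * len(disc_fake_a))
    # Cycle consistency: A -> B -> A and B -> A -> B should reconstruct the inputs.
    cycle = mean_abs_error(real_a, cycled_a) + mean_abs_error(real_b, cycled_b)
    # Identity: same_a is assumed to be the B->A generator applied to real_a
    # (and same_b the A->B generator applied to real_b); both should be unchanged.
    identity = mean_abs_error(real_a, same_a) + mean_abs_error(real_b, same_b)
    total = (adv_a2b + adv_b2a
             + lambda_cycle * cycle
             + lambda_cycle * lambda_identity * identity)
    return {"adv_a2b": adv_a2b, "adv_b2a": adv_b2a,
            "cycle": cycle, "identity": identity, "total": total}
```

Keeping the terms separate in a dictionary, as `gen_loss_dict` suggests, makes it easy to log each component per step while only `total` is differentiated for the generator update.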