@NMZivkovic · Created February 10, 2019 18:19
# Assumes "import numpy as np" at the top of the module.
def train(self, epochs, batch_size, train_data_path):
    # Ground-truth labels for the PatchGAN discriminators:
    # "real" patches are all ones, "fake" patches all zeros.
    real = np.ones((batch_size,) + self.disc_patch)
    fake = np.zeros((batch_size,) + self.disc_patch)
    history = []
    for epoch in range(epochs):
        for i, (imagesX, imagesY) in enumerate(
                self._image_helper.load_batch_of_train_images(train_data_path, batch_size)):
            print("---------------------------------------------------------")
            print("******************Epoch {} | Batch {}***************************".format(epoch, i))

            print("Generate images...")
            # Translate each domain into the other with the two generators.
            fakeY = self._generatorXY.predict(imagesX)
            fakeX = self._generatorYX.predict(imagesY)

            print("Train Discriminators...")
            # Each discriminator is trained on real images from its own domain
            # and on the generated fakes; its loss is the average of the two.
            discriminatorX_loss_real = self._discriminatorX.train_on_batch(imagesX, real)
            discriminatorX_loss_fake = self._discriminatorX.train_on_batch(fakeX, fake)
            discriminatorX_loss = 0.5 * np.add(discriminatorX_loss_real, discriminatorX_loss_fake)
            discriminatorY_loss_real = self._discriminatorY.train_on_batch(imagesY, real)
            discriminatorY_loss_fake = self._discriminatorY.train_on_batch(fakeY, fake)
            discriminatorY_loss = 0.5 * np.add(discriminatorY_loss_real, discriminatorY_loss_fake)
            mean_discriminator_loss = 0.5 * np.add(discriminatorX_loss, discriminatorY_loss)

            print("Train Generators...")
            # One update of the combined model: adversarial validity targets
            # plus the original images as reconstruction/identity targets.
            generator_loss = self.gan.train_on_batch([imagesX, imagesY],
                                                     [real, real,
                                                      imagesX, imagesY,
                                                      imagesX, imagesY])

            print("Discriminator loss: {}".format(mean_discriminator_loss[0]))
            print("Generator loss: {}".format(generator_loss[0]))
            print("---------------------------------------------------------")

            # Store scalar total losses so they can be plotted later.
            history.append({"D": mean_discriminator_loss[0], "G": generator_loss[0]})

            # Periodically save sample translations for visual inspection.
            if i % 100 == 0:
                self._save_images("{}_{}".format(epoch, i), train_data_path)

    # Plot the loss curves once training has finished.
    self._plot_loss(history)
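
A hedged usage sketch for context: the method above belongs to a CycleGAN wrapper class that is not shown in this gist, so the class name CycleGAN, its no-argument constructor, and the data path below are illustrative assumptions, not part of the original code.

# Hypothetical call site; names and arguments are assumptions for illustration only.
gan = CycleGAN()                       # assumed wrapper class exposing train()
gan.train(epochs=100,                  # number of passes over the training set
          batch_size=1,                # small batches are typical for CycleGAN training
          train_data_path="data/")     # placeholder path to the training images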