Skip to content

Instantly share code, notes, and snippets.

@NMZivkovic
Created January 19, 2019 14:55
Show Gist options
  • Save NMZivkovic/829ac6ea72027c44a702b1592e65a469 to your computer and use it in GitHub Desktop.
Save NMZivkovic/829ac6ea72027c44a702b1592e65a469 to your computer and use it in GitHub Desktop.
def train(self, epochs, train_data, batch_size):
    """Run the adversarial training loop and record per-epoch losses.

    Each epoch draws one random batch (sampled with replacement) from
    ``train_data``, updates the discriminator on latent vectors, then
    updates the combined ``self.aae`` model on the same batch. After the
    last epoch the collected loss history is plotted and a GIF is built
    from the ``generated-aae/`` directory.

    Args:
        epochs: Number of training iterations; one batch is used per epoch.
        train_data: Array of training samples, indexed along its first axis.
        batch_size: Number of samples drawn per epoch.
    """
    # Discriminator targets: samples drawn from the normal prior are
    # labelled 1 ("real"), encoder outputs are labelled 0 ("fake").
    real = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))
    separator = "---------------------------------------------------------"

    history = []
    for epoch in range(epochs):
        # Train Discriminator
        batch_indexes = np.random.randint(0, train_data.shape[0], batch_size)
        batch = train_data[batch_indexes]

        latent_vector_fake = self.encoder_model.predict(batch)
        latent_vector_real = np.random.normal(
            size=(batch_size, self.latent_dimension)
        )

        loss_real = self.discriminator_model.train_on_batch(
            latent_vector_real, real
        )
        loss_fake = self.discriminator_model.train_on_batch(
            latent_vector_fake, fake
        )
        # Element-wise mean of the two results. Indexing with [0] below
        # assumes train_on_batch returns a sequence whose first entry is
        # the loss (i.e. the discriminator was compiled with metrics)
        # -- TODO confirm against the model definition.
        discriminator_loss = 0.5 * np.add(loss_real, loss_fake)

        # Train Generator: the combined model receives the batch itself and
        # the "real" labels as its two targets.
        generator_loss = self.aae.train_on_batch(batch, [batch, real])

        # Plot the progress
        print(separator)
        print("******************Epoch {}***************************".format(epoch))
        print("Discriminator loss: {}".format(discriminator_loss[0]))
        print("Generator loss: {}".format(generator_loss))
        print(separator)

        history.append({"D": discriminator_loss[0], "G": generator_loss})

        # Save generated images every hundredth epoch
        if epoch % 100 == 0:
            self._save_images(epoch)

    # Runs once, after all epochs have completed.
    self._plot_loss(history)
    self._image_helper.makegif("generated-aae/")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment