Skip to content

Instantly share code, notes, and snippets.

@NMZivkovic

NMZivkovic/gan_train.py

Last active Dec 16, 2018
Embed
What would you like to do?
def train(self, epochs, train_data, batch_size, *, snapshot_interval=100):
    """Train the GAN with alternating discriminator / generator updates.

    Each loop iteration (reported as an "epoch", although it processes only a
    single randomly sampled mini-batch) performs:

    1. One discriminator step on real samples labelled 1 and one step on
       samples from ``self._predict_noise`` labelled 0; the two results are
       averaged.
    2. One step of ``self.gan`` on fresh standard-normal noise labelled 1,
       i.e. the generator is trained to make the discriminator say "real".

    Args:
        epochs: Number of training iterations to run.
        train_data: Array of real samples indexed along axis 0.
        batch_size: Number of samples per discriminator / generator step.
        snapshot_interval: ``self._save_images(epoch)`` is called whenever
            ``epoch % snapshot_interval == 0`` (so iteration 0 is always
            saved). Defaults to 100, the previously hard-coded value.

    Returns:
        List of per-iteration records ``{"D": discriminator_loss,
        "G": generator_loss}``; the same list is passed to
        ``self._plot_loss``.

    Raises:
        ValueError: If ``snapshot_interval`` is less than 1.
    """
    if snapshot_interval < 1:
        raise ValueError("snapshot_interval must be >= 1")

    def _loss_value(result):
        # train_on_batch returns a bare scalar when the model has no metrics
        # and [loss, *metrics] when it does; atleast_1d(...)[0] yields the
        # loss in both cases instead of failing on a 0-d value.
        return np.atleast_1d(result)[0]

    real = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))
    history = []
    for epoch in range(epochs):
        # --- Train discriminator ---
        # Sample (with replacement) a mini-batch of real examples.
        batch_indexes = np.random.randint(0, train_data.shape[0], batch_size)
        batch = train_data[batch_indexes]
        generated = self._predict_noise(batch_size)
        loss_real = self.discriminator_model.train_on_batch(batch, real)
        loss_fake = self.discriminator_model.train_on_batch(generated, fake)
        discriminator_loss = _loss_value(0.5 * np.add(loss_real, loss_fake))

        # --- Train generator ---
        # Labels are "real" so the generator is rewarded for fooling D.
        noise = np.random.normal(0, 1, (batch_size, self.generator_input_dim))
        generator_loss = _loss_value(self.gan.train_on_batch(noise, real))

        # Report progress on stdout.
        print("---------------------------------------------------------")
        print("******************Epoch {}***************************".format(epoch))
        print("Discriminator loss: {}".format(discriminator_loss))
        print("Generator loss: {}".format(generator_loss))
        print("---------------------------------------------------------")
        history.append({"D": discriminator_loss, "G": generator_loss})

        # Periodically save a snapshot of generated images.
        if epoch % snapshot_interval == 0:
            self._save_images(epoch)
    self._plot_loss(history)
    self._image_helper.makegif("generated/")
    return history
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment