Skip to content

Instantly share code, notes, and snippets.

@deeperunderstanding
Last active July 25, 2019 20:18
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save deeperunderstanding/529fdb7d605ba0bbf7324f4338842d30 to your computer and use it in GitHub Desktop.
Save deeperunderstanding/529fdb7d605ba0bbf7324f4338842d30 to your computer and use it in GitHub Desktop.
# Training hyper-parameters: number of minibatch iterations and samples per batch.
batches = 10000
batch_size = 64

# Per-iteration histories: prior-discriminator, category-discriminator,
# autoencoder training and autoencoder evaluation results.
losses_disc, losses_disc_cat, losses_ae, losses_val = [], [], [], []

# Discriminator target labels, one row per sample in a minibatch:
# ones mark "real" samples, zeros mark "fake" samples.
real = np.ones(shape=(batch_size, 1))
fake = np.zeros(shape=(batch_size, 1))
def discriminator_training(discriminator, real, fake):
    """Build a closure that performs one training step on ``discriminator``.

    The returned function shows the discriminator a batch of real samples
    labelled ``real`` and a batch of fake samples labelled ``fake`` in two
    separate ``train_on_batch`` calls, then averages the two results.

    Args:
        discriminator: Model exposing ``train_on_batch(x, y)`` and a
            ``trainable`` attribute (presumably an already-compiled Keras
            model — confirm against its construction).
        real: Target labels used for the real samples.
        fake: Target labels used for the fake samples.

    Returns:
        A function ``train(real_samples, fake_samples)`` that returns the
        element-wise mean of the two ``train_on_batch`` results.
    """

    def train(real_samples, fake_samples):
        # NOTE(review): in Keras, a change to ``trainable`` only takes effect
        # for models compiled *after* the change; toggling it around
        # ``train_on_batch`` on an already-compiled model does not by itself
        # alter which weights are updated. Confirm this is the intent.
        discriminator.trainable = True
        try:
            loss_real = discriminator.train_on_batch(real_samples, real)
            loss_fake = discriminator.train_on_batch(fake_samples, fake)
        finally:
            # Always freeze again, even if a training step raised, so the
            # discriminator is never left trainable by accident.
            discriminator.trainable = False
        return np.add(loss_real, loss_fake) * 0.5

    return train
# One training-step closure per discriminator; both share the same real/fake
# label arrays defined above.
train_prior_discriminator = discriminator_training(prior_discriminator, real, fake)
train_cat_discriminator = discriminator_training(cat_discriminator, real, fake)

# Main training loop: `batches` iterations, each on a freshly sampled
# minibatch. Statement order matters here (discriminator updates first, then
# the autoencoder update, then evaluation; NumPy RNG calls happen in a fixed
# sequence), so preserve it when editing.
pbar = tqdm(range(batches))
for _ in pbar:
    # Draw batch_size row indices uniformly at random from the training set
    # (with replacement: np.random.randint may repeat indices).
    ids = np.random.randint(0, train_x.shape[0], batch_size)
    signals = train_x[ids]

    # The encoder yields four outputs; the 2nd and 3rd are used as the "fake"
    # latent code and category samples. Presumably these are the continuous
    # latent and categorical heads — confirm against the encoder definition.
    _, latent_fake, category_fake, _ = encoder.predict(signals)
    # "Real" samples come from helpers defined elsewhere (presumably a normal
    # prior and a categorical prior — TODO confirm).
    latent_real = sample_normal(latent_dim, batch_size)
    category_real = sample_categories(cat_dim, batch_size)

    # One discriminator step each: prior samples are trained against the
    # `real` labels, encoder outputs against the `fake` labels.
    prior_loss = train_prior_discriminator(latent_real, latent_fake)
    cat_loss = train_cat_discriminator(category_real, category_fake)
    losses_disc.append(prior_loss)
    losses_disc_cat.append(cat_loss)

    # Autoencoder step: reconstruct `signals`, and push both adversarial
    # outputs towards the `real` label so the encoder learns to fool the
    # discriminators. Assumes the autoencoder's outputs are ordered
    # [reconstruction, prior-discriminator, category-discriminator] — confirm.
    encoder_loss = autoencoder.train_on_batch(signals, [signals, real, real])
    losses_ae.append(encoder_loss)
    # NOTE(review): this "validation" loss is evaluated on the same training
    # minibatch that was just used for the update, not on held-out data, so
    # it is not a true validation metric. Consider evaluating on a separate
    # validation set instead.
    val_loss = autoencoder.test_on_batch(signals, [signals, real, real])
    losses_val.append(val_loss)

    # Index 1 of each result is displayed as prior/category accuracy and as
    # train/val MSE. This assumes the models were compiled so that index 1
    # holds exactly those metrics — TODO confirm against the compile calls.
    pbar.set_description("[Acc. Prior/Cat: %.2f%% / %.2f%%] [MSE train/val: %f / %f]"
                         % (100*prior_loss[1], 100*cat_loss[1], encoder_loss[1], val_loss[1]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment