Skip to content

Instantly share code, notes, and snippets.

@emuccino
Last active April 13, 2020 09:35
Show Gist options
  • Save emuccino/bb14e6c18958210d8cdab10ce4f82cc1 to your computer and use it in GitHub Desktop.
Train GAN
import itertools

import matplotlib.pyplot as plt
import numpy as np
def train_gan(n_epochs, n_batch, n_plot, n_eval):
    """Train the GAN, alternating discriminator and generator updates.

    Relies on module-level globals defined elsewhere in the script:
    ``discriminator``, ``gan``, ``generate_real_samples``,
    ``generate_synthetic_samples``, ``generate_latent_samples``,
    ``n_tokens``, ``numeric_data`` and ``string_data``.

    Args:
        n_epochs: number of training iterations.
        n_batch: samples per iteration (split evenly real/synthetic for
            the discriminator; full size for the generator step).
        n_plot: number of samples drawn for the evaluation plots.
        n_eval: run the plotting/evaluation pass every ``n_eval`` epochs.

    Returns:
        ``(disc_loss_hist, gen_loss_hist)``: per-epoch training losses.
    """
    # discriminator/generator training logs
    disc_loss_hist = []
    gen_loss_hist = []

    for epoch in range(n_epochs):
        if epoch % 100 == 0:
            print(epoch, end=' ')

        # enable discriminator training
        discriminator.trainable = True

        # sample equal portions of real/synthetic data
        x_real, y_real = generate_real_samples(n_batch // 2)
        x_synth, y_synth = generate_synthetic_samples(n_batch // 2)
        x_total = {}
        for key in x_real.keys():
            x_total[key] = np.vstack([x_real[key], x_synth[key]])
        y_total = np.vstack([y_real, y_synth])

        # train discriminator
        hist = discriminator.train_on_batch(x_total, y_total)
        disc_loss_hist.append(hist)

        # freeze the discriminator so the combined model only updates the generator
        discriminator.trainable = False
        x_gan, y_gan = generate_latent_samples(n_batch)

        # train generator
        hist = gan.train_on_batch(x_gan, y_gan)
        gen_loss_hist.append(hist)

        # after set number of epochs, evaluate GAN training progress
        if (epoch + 1) % n_eval == 0:
            print('\n')

            # pull real and synthetic data to compare distributions and relationships
            x_real, _ = generate_real_samples(n_plot // 2)
            x_synth, _ = generate_synthetic_samples(n_plot // 2)

            # collapse one-hot token columns back to integer indices for plotting
            for name in n_tokens:
                x_real[name] = x_real[name].argmax(1).reshape(-1, 1)
                x_synth[name] = x_synth[name].argmax(1).reshape(-1, 1)

            print('numeric data')
            for i, name1 in enumerate(numeric_data):
                print(name1)
                plt.hist([x_real[name1].flatten(), x_synth[name1].flatten()],
                         bins=16)  # compare data distributions
                plt.legend(['Real', 'Synthetic'])
                plt.show()
                for name2 in numeric_data[i + 1:]:
                    print(name1, name2)
                    plt.scatter(x_real[name1], x_real[name2], s=1)  # compare data relationships
                    plt.scatter(x_synth[name1], x_synth[name2], s=1)
                    plt.legend(['Real', 'Synthetic'])
                    plt.show()

            print('string data')
            for i, name1 in enumerate(string_data):
                print(name1)
                plt.hist([x_real[name1].flatten(), x_synth[name1].flatten()],
                         bins=n_tokens[name1])  # compare data distributions
                plt.legend(['Real', 'Synthetic'])
                plt.show()
                for name2 in string_data[i + 1:]:
                    print(name1, name2)
                    # create numerical index to represent combinations of tokens
                    lookup = {tup: p for p, tup in enumerate(itertools.product(
                        range(n_tokens[name1]), range(n_tokens[name2])))}
                    hist_real = [lookup[tuple(x)] for x in np.hstack([x_real[name1], x_real[name2]])]
                    hist_synth = [lookup[tuple(x)] for x in np.hstack([x_synth[name1], x_synth[name2]])]
                    plt.hist([hist_real, hist_synth],
                             bins=len(set(hist_real + hist_synth)),
                             color=['blue', 'orange'])  # compare data relationships
                    plt.legend(['Real', 'Synthetic'])
                    plt.show()

            # plot loss history accumulated so far
            print('loss history')
            plt.plot(disc_loss_hist, linewidth=2)
            plt.plot(gen_loss_hist, linewidth=2)
            plt.legend(['Discriminator', 'Generator'])
            plt.show()
            print('\n')

    # NOTE(review): original returned None; histories are now returned so
    # callers can inspect training progress. Existing caller ignores the
    # return value, so this is backward-compatible.
    return disc_loss_hist, gen_loss_hist
# Training configuration for the GAN run.
n_epochs = 3000      # total training iterations
n_batch = 16 * 1024  # samples per iteration
n_eval = 500         # evaluation/plotting interval, in epochs
n_plot = 2048        # samples drawn for evaluation plots

train_gan(n_epochs, n_batch, n_plot, n_eval)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment