Skip to content

Instantly share code, notes, and snippets.

@emuccino
Last active April 13, 2020 08:28
Show Gist options
  • Save emuccino/006d61fdf1af563f7e2ea80517663b8a to your computer and use it in GitHub Desktop.
Train classifiers
# Builder for the classifier network (architecture mirrors the discriminator).
def compile_classifier():
    """Build and compile a fresh 2-class classifier over the mixed-type inputs.

    Each numeric feature gets its own noisy scalar input branch; each
    categorical (one-hot) feature gets a noisy input followed by a learned
    dense embedding. All branches merge into a shared dense trunk that ends
    in a 2-way softmax.

    Returns:
        A compiled ``Model`` (categorical cross-entropy, Nadam with
        gradient-norm clipping, categorical-accuracy metric).
    """
    inputs = {}
    numeric_nets = []
    string_nets = []

    # One lightly-noised scalar branch per numeric feature.
    for name in numeric_data:
        branch_in = Input(shape=(1,), name=name)
        inputs[name] = branch_in
        numeric_nets.append(GaussianNoise(0.01)(branch_in))

    # One noised + embedded branch per categorical feature
    # (n_token = vocabulary size of the one-hot input).
    for name, n_token in n_tokens.items():
        branch_in = Input(shape=(n_token,), name=name)
        inputs[name] = branch_in
        branch = GaussianNoise(0.05)(branch_in)
        branch = Dense(n_embeddings[name], activation='relu',
                       kernel_initializer='he_uniform')(branch)
        string_nets.append(branch)

    # Fuse the categorical branches and project them to a single vector.
    merged = Concatenate()(string_nets)
    merged = BatchNormalization()(merged)
    string_nets = [Dense(len(string_data), activation='relu',
                         kernel_initializer='he_uniform')(merged)]

    # Shared trunk: concat everything, then 4 dense+batchnorm stages.
    net = Concatenate()(numeric_nets + string_nets)
    net = BatchNormalization()(net)
    for _ in range(4):
        net = Dense(128, activation='relu',
                    kernel_initializer='he_uniform')(net)
        net = BatchNormalization()(net)

    # Two-way softmax head.
    outputs = Dense(2, activation='softmax',
                    kernel_initializer='glorot_uniform')(net)

    classifier = Model(inputs=inputs, outputs=outputs)
    classifier.compile(loss='categorical_crossentropy',
                       optimizer=Nadam(clipnorm=1.),
                       metrics=['categorical_accuracy'])
    return classifier
# Two identical models: one will only ever see real data, the other a
# real/synthetic mix produced by the GAN.
classifier = compile_classifier()
gan_classifier = compile_classifier()

batch_size = 512

# Train the real-data classifier for 1000 batches.
# NOTE(review): y_real[:,1:] drops the first label column — presumably a
# real/fake flag, leaving the 2-class one-hot target; confirm against
# generate_real_samples.
for _ in range(1000):
    x_real, y_real = generate_real_samples(batch_size)
    classifier.train_on_batch(x_real, y_real[:, 1:])
# Train the augmented classifier for 1000 batches, each assembled from
# half real and half synthetic samples.
for _ in range(1000):
    x_real, y_real = generate_real_samples(batch_size // 2)
    x_synth, y_synth = generate_synthetic_samples(batch_size // 2)
    # Stack the two halves feature-by-feature into one batch dict.
    x_total = {name: np.vstack([x_real[name], x_synth[name]])
               for name in x_real}
    y_total = np.vstack([y_real, y_synth])
    # Same label slicing as the real-only loop (drops the first column).
    gan_classifier.train_on_batch(x_total, y_total[:, 1:])
# Assemble the held-out test set in the same per-feature input format the
# models were built with.
test_inputs = {name: test_df[[name]].values for name in numeric_data}
for name in string_data:
    test_inputs[name] = to_categorical(test_df[name].values, n_tokens[name])
test_outputs = to_categorical(test_target_df[target].values, 2)

# Compare held-out performance; evaluate() returns [loss, metric, ...],
# so index 1 is the categorical accuracy.
classifier_eval = classifier.evaluate(test_inputs, test_outputs)
print('classifier accuracy:', classifier_eval[1])
gan_classifier_eval = gan_classifier.evaluate(test_inputs, test_outputs)
print('gan classifier accuracy:', gan_classifier_eval[1])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment