Skip to content

Instantly share code, notes, and snippets.

@chaserileyroberts
Created October 19, 2017 05:22
Show Gist options
  • Save chaserileyroberts/4c7efec31b30f50892c27d44f9f055be to your computer and use it in GitHub Desktop.
Save chaserileyroberts/4c7efec31b30f50892c27d44f9f055be to your computer and use it in GitHub Desktop.
def test_gen_training():
    """Check that one generator training step updates only generator weights.

    Builds the model in a fresh graph and snapshots the trainable variables
    under the 'gen' and 'des' variable scopes. Then it runs
    ``model.train_gen`` once. Finally it asserts that every generator
    variable changed and every discriminator variable stayed exactly the
    same.
    """
    # A private graph keeps variables created by other tests (or by
    # module-level code) out of the 'gen'/'des' collections below.
    with tf.Graph().as_default():
        # The original bound the class itself (`Model`) instead of an
        # instance, so nothing was built by this test.
        # NOTE(review): assumes Model() builds the graph and needs no
        # required arguments — TODO confirm against Model's constructor.
        model = Model()

        # TRAINABLE_VARIABLES (instead of the deprecated GraphKeys.VARIABLES
        # alias of GLOBAL_VARIABLES) limits the check to weights an optimizer
        # updates. Optimizer slots, step counters and moving statistics then
        # cannot cause spurious failures.
        gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='gen')
        des_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='des')

        # Guard against a vacuous pass: zip() over an empty list checks nothing.
        assert gen_vars, "no trainable variables found under scope 'gen'"
        assert des_vars, "no trainable variables found under scope 'des'"

        with tf.Session() as sess:
            # Variables must be initialized before they can be read. This
            # runs after Model() so any optimizer variables are included.
            sess.run(tf.global_variables_initializer())
            before_gen, before_des = sess.run([gen_vars, des_vars])
            # Train the generator for a single step.
            sess.run(model.train_gen)
            after_gen, after_des = sess.run([gen_vars, des_vars])

    # Make sure every generator variable changed.
    for var, before, after in zip(gen_vars, before_gen, after_gen):
        assert (after != before).any(), (
            "generator variable {} did not change".format(var.name))
    # Make sure the discriminator did NOT change.
    for var, before, after in zip(des_vars, before_des, after_des):
        assert (after == before).all(), (
            "discriminator variable {} changed".format(var.name))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment