Created
October 19, 2017 05:22
-
-
Save chaserileyroberts/4c7efec31b30f50892c27d44f9f055be to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def test_gen_training():
    """Check that one generator training step only updates generator variables.

    Snapshots every variable collected under the ``gen`` and ``des`` scopes,
    runs ``model.train_gen`` once, and then asserts that each generator
    variable changed while each discriminator variable kept its value.
    """
    # NOTE(review): this binds the ``Model`` name itself without calling it;
    # confirm whether ``Model()`` was intended.
    model = Model
    gen_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='gen')
    des_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='des')
    # With an empty collection the comparison loops below would never run
    # and the test would pass without checking anything.
    assert gen_vars, "no variables found under scope 'gen'"
    assert des_vars, "no variables found under scope 'des'"

    # The context manager closes the session even if a run call raises.
    with tf.Session() as sess:
        # Variables have no value in a fresh session until initialized.
        sess.run(tf.global_variables_initializer())
        before_gen = sess.run(gen_vars)
        before_des = sess.run(des_vars)
        # Train the generator.
        sess.run(model.train_gen)
        after_gen = sess.run(gen_vars)
        after_des = sess.run(des_vars)

    # Make sure the generator variables changed.
    for before, after in zip(before_gen, after_gen):
        assert (after != before).any()
    # Make sure the discriminator did NOT change.
    for before, after in zip(before_des, after_des):
        assert (after == before).all()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment