Created April 28, 2018 12:34
-
-
Save nogawanogawa/3a3e329c6c971e7137cc93c5f90978d8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def loss_calc(self, cyc_weight=10.0):
    """Build the CycleGAN losses, Adam train ops and scalar loss summaries.

    Only graph nodes are created here; nothing is run.

    Reads tensors that must already exist on ``self``:

    * ``x``, ``y``, ``cyc_x``, ``cyc_y``;
    * ``rec_x``, ``rec_y``, ``fake_rec_x``, ``fake_rec_y``,
      ``fake_pool_rec_x``, ``fake_pool_rec_y``;
    * the learning rate ``lr``.

    Sets on ``self``:

    * ``model_vars`` (all trainable variables);
    * the train ops ``d_x_trainer``, ``d_y_trainer``, ``g_x_trainer``,
      ``g_y_trainer``;
    * the summaries ``g_x_loss_summ``, ``g_y_loss_summ``, ``d_x_loss_summ``,
      ``d_y_loss_summ``;
    * the loss tensors ``cyc_loss``, ``g_x_loss``, ``g_y_loss``,
      ``d_x_loss``, ``d_y_loss``.

    Args:
        cyc_weight: Weight of the cycle-consistency term in each generator
            loss. Defaults to 10.0, the value previously hard-coded.
    """
    # Cycle-consistency loss: mean absolute (L1) difference between each
    # input and its cyc_* counterpart, summed over both domains.
    # NOTE(review): cyc_x / cyc_y are presumably the round-trip
    # reconstructions (x -> y -> x, y -> x -> y) — confirm against the graph.
    cyc_loss = (tf.reduce_mean(tf.abs(self.x - self.cyc_x))
                + tf.reduce_mean(tf.abs(self.y - self.cyc_y)))

    # Adversarial terms for the generators, in least-squares form. They push
    # the scores fake_rec_* towards the "real" target 1.
    # Despite the disc_ prefix, they are used only in the generator
    # objectives below, not to train the discriminators.
    disc_loss_x = tf.reduce_mean(tf.squared_difference(self.fake_rec_x, 1))
    disc_loss_y = tf.reduce_mean(tf.squared_difference(self.fake_rec_y, 1))

    # Generator objectives: weighted cycle loss plus the opposite domain's
    # adversarial term. g_x is scored by the y-side term, so g_x presumably
    # maps x -> y (and g_y maps y -> x) — confirm against graph construction.
    g_loss_x = cyc_loss * cyc_weight + disc_loss_y
    g_loss_y = cyc_loss * cyc_weight + disc_loss_x

    # Discriminator objectives, in least-squares form and averaged over the
    # two terms: fake_pool_rec_* is pushed towards 0 and rec_* towards 1.
    # NOTE(review): fake_pool_rec_* are presumably the scores on fakes drawn
    # from an image pool, and rec_* the scores on real inputs — confirm.
    d_loss_x = (tf.reduce_mean(tf.square(self.fake_pool_rec_x))
                + tf.reduce_mean(tf.squared_difference(self.rec_x, 1))) / 2.0
    d_loss_y = (tf.reduce_mean(tf.square(self.fake_pool_rec_y))
                + tf.reduce_mean(tf.squared_difference(self.rec_y, 1))) / 2.0

    # NOTE(review): one AdamOptimizer instance is shared by all four
    # minimize() calls below. In TF1, Adam's beta1_power/beta2_power
    # accumulators belong to the optimizer instance, not the train op, so
    # each of the four train ops advances the bias correction. Separate
    # optimizers would avoid this but would also rename those variables and
    # break restoring existing checkpoints, so the sharing is kept — verify
    # for the TF version in use.
    optimizer = tf.train.AdamOptimizer(self.lr, beta1=0.5)

    # Partition the trainable variables per network by substring match on
    # the variable name. Any variable whose name merely contains e.g. 'd_x'
    # is picked up.
    self.model_vars = tf.trainable_variables()
    d_x_vars = [var for var in self.model_vars if 'd_x' in var.name]
    g_x_vars = [var for var in self.model_vars if 'g_x' in var.name]
    d_y_vars = [var for var in self.model_vars if 'd_y' in var.name]
    g_y_vars = [var for var in self.model_vars if 'g_y' in var.name]

    # Each train op updates only its own network's variables, even though
    # the generator losses depend on both generators through cyc_loss.
    self.d_x_trainer = optimizer.minimize(d_loss_x, var_list=d_x_vars)
    self.d_y_trainer = optimizer.minimize(d_loss_y, var_list=d_y_vars)
    self.g_x_trainer = optimizer.minimize(g_loss_x, var_list=g_x_vars)
    self.g_y_trainer = optimizer.minimize(g_loss_y, var_list=g_y_vars)

    # Expose the loss tensors so callers can fetch their values directly.
    self.cyc_loss = cyc_loss
    self.g_x_loss = g_loss_x
    self.g_y_loss = g_loss_y
    self.d_x_loss = d_loss_x
    self.d_y_loss = d_loss_y

    self.g_x_loss_summ = tf.summary.scalar("g_x_loss", g_loss_x)
    self.g_y_loss_summ = tf.summary.scalar("g_y_loss", g_loss_y)
    self.d_x_loss_summ = tf.summary.scalar("d_x_loss", d_loss_x)
    self.d_y_loss_summ = tf.summary.scalar("d_y_loss", d_loss_y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.