Skip to content

Instantly share code, notes, and snippets.

View ethanyanjiali's full-sized avatar

Ethan Yanjia Li ethanyanjiali

View GitHub Profile
@ethanyanjiali
ethanyanjiali / cyclegan_train_generator.py
Created June 6, 2019 06:00
cyclegan_train_generator
@tf.function
def train_generator(images_a, images_b):
    """Generator training step for CycleGAN (excerpt).

    NOTE(review): this snippet stops right after the identity losses. A
    GradientTape is opened, but in the visible code no gradients are computed
    or applied, ``tape`` is never read, and nothing is returned. Presumably the
    adversarial loss, the cycle loss and the optimizer step follow in the full
    source. Confirm before relying on this function as shown.

    Args:
        images_a: Real images from domain A (presumably a batch tensor --
            TODO confirm shape against the input pipeline).
        images_b: Real images from domain B (same assumption as images_a).
    """
    real_a = images_a
    real_b = images_b
    with tf.GradientTape() as tape:
        # Identity loss: feeding a real B image into the A->B generator should
        # give that B image back unchanged, and likewise a real A image fed
        # into the B->A generator.
        identity_a2b = generator_a2b(real_b, training=True)
        identity_b2a = generator_b2a(real_a, training=True)
        loss_identity_a2b = calc_identity_loss(identity_a2b, real_b)
        loss_identity_b2a = calc_identity_loss(identity_b2a, real_a)
def calc_gan_loss(prediction, is_real):
    """Score a discriminator output against a constant real/fake label.

    This sets the usual GAN objectives for both the generator and the
    discriminator.

    Args:
        prediction: Discriminator output tensor.
        is_real: When truthy, the label is all ones ("real"). Otherwise the
            label is all zeros ("fake").

    Returns:
        ``mse_loss(prediction, label)``, where ``label`` has the same shape as
        ``prediction``. ``mse_loss`` is defined elsewhere in the file.
    """
    # Build the constant label first (shape-matched to the prediction), then
    # score it with a single loss call.
    label = tf.ones_like(prediction) if is_real else tf.zeros_like(prediction)
    return mse_loss(prediction, label)
def calc_cycle_loss(reconstructed_images, real_images):
    """Cycle-consistency loss between round-trip reconstructions and their sources.

    After an image has been translated to the other domain and back, the
    reconstruction should match the original real image. This function
    measures the mismatch with ``mae_loss``, which is defined elsewhere in the
    file (presumably mean absolute error / L1 -- confirm).

    Args:
        reconstructed_images: Images after the A->B->A (or B->A->B) round trip.
        real_images: The original images that were fed into the round trip.

    Returns:
        ``mae_loss(reconstructed_images, real_images)``.
    """
    cycle_loss = mae_loss(reconstructed_images, real_images)
    return cycle_loss
def make_generator_model(n_blocks):
    """Build the CycleGAN generator (residual-block variant) as a keras Sequential (excerpt).

    NOTE(review): only the first encoder layers are visible here. ``n_blocks``
    is not used, and nothing is returned in this excerpt. Presumably the
    downsampling layers, the residual blocks, the upsampling layers and a
    ``return model`` follow in the full source. Confirm before use.

    Args:
        n_blocks: Number of residual blocks. The layout comments below describe
            the 6- and 9-block variants.
    """
    # 6 residual blocks
    # c7s1-64,d128,d256,R256,R256,R256,R256,R256,R256,u128,u64,c7s1-3
    # 9 residual blocks
    # c7s1-64,d128,d256,R256,R256,R256,R256,R256,R256,R256,R256,R256,u128,u64,c7s1-3
    model = tf.keras.Sequential()
    # Encoding
    # c7s1-64: 7x7 conv, stride 1, 64 filters. The input is fixed at 256x256x3.
    # ReflectionPad2d(3) is defined elsewhere and presumably pads 3 pixels per
    # side. With that padding, the following 'valid' 7x7 convolution keeps the
    # 256x256 spatial size (256 + 6 - 7 + 1 = 256).
    model.add(ReflectionPad2d(3, input_shape=(256, 256, 3)))
    # use_bias=False: presumably a normalization layer follows (not visible
    # in this excerpt) -- confirm.
    model.add(tf.keras.layers.Conv2D(64, (7, 7), strides=(1, 1), padding='valid', use_bias=False))
# Learning-rate schedules for the two optimizers. The generator and the
# discriminator each get their own LinearDecay instance built from identical
# arguments. NOTE(review): LinearDecay is defined elsewhere. Its positional
# arguments presumably mean (initial LR, total training steps, step count tied
# to DECAY_EPOCHS) -- confirm against its definition.
gen_lr_scheduler = LinearDecay(
    LEARNING_RATE,
    EPOCHS * total_batches,
    DECAY_EPOCHS * total_batches,
)
dis_lr_scheduler = LinearDecay(
    LEARNING_RATE,
    EPOCHS * total_batches,
    DECAY_EPOCHS * total_batches,
)
# Adam driven by each schedule. Keyword arguments make explicit which Adam
# parameters the positional call was filling (learning_rate, beta_1).
optimizer_gen = tf.keras.optimizers.Adam(learning_rate=gen_lr_scheduler, beta_1=BETA_1)
optimizer_dis = tf.keras.optimizers.Adam(learning_rate=dis_lr_scheduler, beta_1=BETA_1)