Skip to content

Instantly share code, notes, and snippets.

@ethanyanjiali
Created June 6, 2019 05:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ethanyanjiali/113b728f52a46cb6dda1c627c07d81aa to your computer and use it in GitHub Desktop.
Save ethanyanjiali/113b728f52a46cb6dda1c627c07d81aa to your computer and use it in GitHub Desktop.
cyclegan_generator
def _add_bn_relu(model):
    """Append a BatchNormalization layer followed by a ReLU to *model*."""
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.ReLU())


def make_generator_model(n_blocks, *, input_shape=(256, 256, 3), output_channels=3):
    """Build the CycleGAN ResNet-style generator as a ``tf.keras.Sequential``.

    Layer layout (CycleGAN paper notation):
      6 residual blocks:
        c7s1-64,d128,d256,R256,R256,R256,R256,R256,R256,u128,u64,c7s1-3
      9 residual blocks:
        c7s1-64,d128,d256,R256,R256,R256,R256,R256,R256,R256,R256,R256,u128,u64,c7s1-3

    Args:
        n_blocks: Number of ``ResNetBlock(256)`` layers in the transformation
            stage (6 and 9 are the configurations listed above). A value
            <= 0 adds no residual blocks.
        input_shape: Shape of one input image as ``(height, width, channels)``,
            excluding the batch dimension. Defaults to the previously
            hard-coded ``(256, 256, 3)``.
        output_channels: Number of channels produced by the final 7x7
            convolution. Defaults to the previously hard-coded 3.

    Returns:
        The assembled ``tf.keras.Sequential`` model. The last layer uses a
        ``tanh`` activation, so outputs lie in [-1, 1].

    Notes:
        The encoder has two stride-2 convolutions with ``padding='same'``
        (each maps a spatial size ``s`` to ``ceil(s / 2)``). The decoder has
        two stride-2 transposed convolutions with ``padding='same'`` (each
        maps ``s`` to ``2 * s``). The output height and width therefore
        match the input only when both are divisible by 4.
    """
    model = tf.keras.Sequential()

    # --- Encoding: c7s1-64, d128, d256 ---
    # Reflection padding of 3 followed by a 7x7 'valid' conv presumably keeps
    # the spatial size unchanged (assumes ReflectionPad2d pads 3 px per side
    # -- TODO confirm against its definition).
    model.add(ReflectionPad2d(3, input_shape=input_shape))
    # use_bias=False: the following BatchNormalization supplies its own
    # learned offset (beta), so a conv bias would be redundant.
    model.add(tf.keras.layers.Conv2D(64, (7, 7), strides=(1, 1), padding='valid', use_bias=False))
    # NOTE(review): the CycleGAN paper uses instance normalization; this model
    # uses BatchNormalization throughout -- confirm this is intended.
    _add_bn_relu(model)
    model.add(tf.keras.layers.Conv2D(128, (3, 3), strides=(2, 2), padding='same', use_bias=False))
    _add_bn_relu(model)
    model.add(tf.keras.layers.Conv2D(256, (3, 3), strides=(2, 2), padding='same', use_bias=False))
    _add_bn_relu(model)

    # --- Transformation: n_blocks x R256 ---
    # Presumably each ResNetBlock preserves the (H/4, W/4, 256) shape -- verify
    # against its definition.
    for _ in range(n_blocks):
        model.add(ResNetBlock(256))

    # --- Decoding: u128, u64, c7s1-<output_channels> ---
    model.add(tf.keras.layers.Conv2DTranspose(128, (3, 3), strides=(2, 2), padding='same', use_bias=False))
    _add_bn_relu(model)
    model.add(tf.keras.layers.Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same', use_bias=False))
    _add_bn_relu(model)
    model.add(ReflectionPad2d(3))
    # Final conv keeps its default bias (no normalization follows it); tanh
    # bounds the generated image to [-1, 1].
    model.add(tf.keras.layers.Conv2D(output_channels, (7, 7), strides=(1, 1), padding='valid', activation='tanh'))
    return model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment