Last active
April 19, 2020 18:53
-
-
Save matheushent/e5ac809c96bb1ab276514e7ed46517bb to your computer and use it in GitHub Desktop.
Standalone code for TensorFlow issue #38518
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow_addons as tfa | |
import tensorflow as tf | |
def get_norm_layer(norm):
    """Return a factory for the requested normalization layer.

    The returned object is meant to be called with no arguments to obtain
    something that can then be applied to a tensor, i.e.
    ``get_norm_layer(norm)()(x)``.

    Args:
        norm: ``None`` for no normalization, or one of ``'batch_norm'``,
            ``'instance_norm'`` or ``'layer_norm'``.

    Returns:
        For ``None``, a factory producing an identity function; otherwise
        the matching layer class from ``tf.keras.layers`` or ``tfa.layers``.

    Raises:
        ValueError: If ``norm`` is not one of the supported values.
    """
    # Identity comparison is the correct test for None; `==` can be
    # overridden by arbitrary objects.
    if norm is None:
        return lambda: lambda x: x
    if norm == 'batch_norm':
        return tf.keras.layers.BatchNormalization
    if norm == 'instance_norm':
        return tfa.layers.InstanceNormalization
    if norm == 'layer_norm':
        return tf.keras.layers.LayerNormalization
    # Fail here with a clear message instead of silently returning None,
    # which would only surface later as "'NoneType' object is not callable".
    raise ValueError(
        "Unsupported normalization {!r}; expected None, 'batch_norm', "
        "'instance_norm' or 'layer_norm'.".format(norm)
    )
class LinearDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Constant learning rate followed by a linear decay towards zero.

    While ``step < epoch_decay`` the schedule yields ``init_learning_rate``.
    From ``epoch_decay`` onwards the rate decreases linearly and reaches zero
    at ``step == epochs``; it is clamped at zero afterwards.

    ``step`` is compared directly against ``epoch_decay``, so ``epochs`` and
    ``epoch_decay`` must be expressed in the same unit as the ``step`` value
    the caller passes in.

    Args:
        init_learning_rate: Learning rate used before the decay starts.
        epochs: Step at which the learning rate reaches zero.
        epoch_decay: Step at which the linear decay starts.
        beta_1: Stored unchanged on the instance; the schedule itself does
            not use it (``ConvDiscriminator`` reads it for its Adam optimizer).

    Raises:
        ValueError: If ``epochs`` is not strictly greater than ``epoch_decay``
            (the decay slope would divide by zero or be inverted).
    """

    def __init__(self, init_learning_rate, epochs, epoch_decay, beta_1):
        super(LinearDecay, self).__init__()
        if epochs <= epoch_decay:
            raise ValueError(
                "epochs ({}) must be greater than epoch_decay ({})".format(
                    epochs, epoch_decay
                )
            )
        self.init_learning_rate = init_learning_rate
        self.epochs = epochs
        self.epoch_decay = epoch_decay
        self.beta_1 = beta_1
        # Mirrors the most recently computed learning rate so it can be
        # inspected from outside the optimizer.
        self.current_learning_rate = tf.Variable(
            initial_value=init_learning_rate,
            trainable=False, dtype=tf.float32
        )

        # Branch functions used by __call__.
        # Linear decay: lr * (1 - (step - start) / (end - start)). The
        # previous form divided by (step - start), which is not linear and
        # blows up at step == start.
        self.true_fn = lambda step: self.init_learning_rate * (
            1 - (1 / (self.epochs - self.epoch_decay)) * (step - self.epoch_decay)
        )
        self.false_fn = lambda: self.init_learning_rate

    def __call__(self, step):
        """Compute the learning rate for ``step``.

        Args:
            step: Scalar step counter (Python number or tensor).

        Returns:
            A scalar ``tf.float32`` tensor holding the learning rate.
        """
        # Work in float32 so the arithmetic does not depend on the dtype of
        # the counter that is passed in.
        step = tf.cast(step, tf.float32)
        # tf.cond needs callables for both branches (not already-computed
        # values), and both branches must return the same dtype.
        learning_rate = tf.cond(
            step >= self.epoch_decay,
            true_fn=lambda: tf.maximum(
                tf.cast(self.true_fn(step), tf.float32), 0.0
            ),
            false_fn=lambda: tf.cast(self.false_fn(), tf.float32),
        )
        self.current_learning_rate.assign(learning_rate)
        # The schedule must return the rate; without a return value the
        # optimizer would receive None.
        return learning_rate

    def get_config(self):
        """Return the constructor arguments needed to rebuild the schedule."""
        return {
            'init_learning_rate': self.init_learning_rate,
            'epochs': self.epochs,
            'epoch_decay': self.epoch_decay,
            'beta_1': self.beta_1,
        }
class ConvDiscriminator:
    """Build and compile a convolutional discriminator.

    The network is a stack of stride-2 4x4 convolutions followed by a
    stride-1 4x4 convolution and a final single-channel 4x4 convolution.
    Section 4 of the CycleGAN paper (https://arxiv.org/abs/1703.10593v6)
    describes its discriminators as 70x70 PatchGANs.

    Args:
        input_shape: Shape of one input image, without the batch dimension.
        dim: Number of filters in the first convolution; later layers double
            it, capped at ``8 * dim``.
        num_downsamplings: Total number of stride-2 convolutions.
        norm: Normalization name passed to ``get_norm_layer``.
        lr_scheduler: Learning-rate schedule exposing a ``beta_1`` attribute.
            When ``None`` (the default) a fresh
            ``LinearDecay(0.0001, 200, 100, 0.5)`` is created.

    Attributes:
        optimizer: Adam optimizer driven by ``lr_scheduler``.
        model: The compiled ``tf.keras.Model``.
    """

    def __init__(self,
                 input_shape=(256, 256, 3),
                 dim=64,
                 num_downsamplings=3,
                 norm='instance_norm',
                 lr_scheduler=None):
        # Create the default schedule per instance: a default built in the
        # signature is evaluated once and its tf.Variable would be shared by
        # every discriminator.
        if lr_scheduler is None:
            lr_scheduler = LinearDecay(0.0001, 200, 100, 0.5)

        self.norm = get_norm_layer(norm)
        self.input_shape = input_shape
        self.dim = dim
        self.dim_ = dim
        self.num_downsamplings = num_downsamplings
        self.lr_scheduler = lr_scheduler

        # build model
        self.optimizer = tf.keras.optimizers.Adam(
            learning_rate=self.lr_scheduler,
            beta_1=self.lr_scheduler.beta_1
        )
        self.model = self.build()

    def build(self):
        """Create and compile the discriminator model.

        Returns:
            A ``tf.keras.Model`` compiled with ``self.optimizer``, MSE loss
            and the ``'accuracy'`` metric.
        """
        # NOTE(review): CONV_KERNEL_INITIALIZER is not defined in the visible
        # part of this file; it must exist at module scope before build() runs.
        # Track the filter count locally so that build() does not mutate
        # self.dim and can be called again with the same result.
        dim = self.dim
        max_dim = self.dim_ * 8

        x = inputs = tf.keras.Input(shape=self.input_shape)

        # 1: downsampling stack (the first convolution has no normalization)
        x = tf.keras.layers.Conv2D(
            dim, 4, strides=2, padding='same',
            kernel_initializer=CONV_KERNEL_INITIALIZER
        )(x)
        x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)

        for _ in range(self.num_downsamplings - 1):
            dim = min(dim * 2, max_dim)
            x = tf.keras.layers.Conv2D(
                dim, 4, strides=2, padding='same',
                kernel_initializer=CONV_KERNEL_INITIALIZER, use_bias=False
            )(x)
            x = self.norm()(x)
            x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)

        # 2: stride-1 convolution
        dim = min(dim * 2, max_dim)
        x = tf.keras.layers.Conv2D(
            dim, 4, strides=1, padding='same',
            kernel_initializer=CONV_KERNEL_INITIALIZER, use_bias=False
        )(x)
        x = self.norm()(x)
        x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)

        # 3: single-channel output map
        x = tf.keras.layers.Conv2D(
            1, 4, strides=1, padding='same',
            kernel_initializer=CONV_KERNEL_INITIALIZER
        )(x)

        model = tf.keras.Model(inputs=inputs, outputs=x)
        model.compile(
            optimizer=self.optimizer,
            loss='mse',
            metrics=['accuracy']
        )
        return model
# Instantiate a discriminator with the default configuration; this builds and
# compiles the Keras model as a side effect of construction.
discriminator = ConvDiscriminator()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment