-
-
Save Laurans/075e3f411db25a255068ecb0c055e753 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def discriminator(images, reuse=False):
    """Build the discriminator network.

    The network is a stack of five stride-2, 5x5 convolutions followed by a
    single-unit dense layer. The first convolution is followed only by a
    leaky ReLU; every later one is followed by batch normalization, a leaky
    ReLU and dropout.

    Args:
        images: Input image tensor fed to the first convolution. Its
            non-batch dimensions must be statically known, because the
            flatten step computes the feature count at graph-build time.
        reuse: Passed to ``tf.variable_scope`` so a second call can share
            the variables created by the first.

    Returns:
        A tuple ``(out, logits)`` where ``logits`` is the raw output of the
        final dense layer and ``out`` is ``tf.sigmoid(logits)``.
    """
    filters = [64, 128, 256, 512, 256]
    alpha = 0.18  # negative-side slope of the leaky ReLU
    with tf.variable_scope('discriminator', reuse=reuse):
        # First layer: convolution + leaky ReLU only (no batch norm/dropout).
        x = tf.layers.conv2d(
            images, filters[0], 5, strides=2, padding='same', activation=None,
            kernel_initializer=tf.contrib.layers.xavier_initializer_conv2d())
        x = tf.maximum(alpha * x, x)

        for size in filters[1:]:
            x = tf.layers.conv2d(
                x, size, 5, strides=2, padding='same', activation=None,
                kernel_initializer=tf.contrib.layers.xavier_initializer_conv2d())
            bn = tf.layers.batch_normalization(x, training=True)
            relu = tf.maximum(alpha * bn, bn)
            # tf.layers.dropout defaults to training=False, which returns its
            # input unchanged. Enable it explicitly, matching the batch norm
            # above, which is also always built in training mode.
            x = tf.layers.dropout(relu, 0.3, training=True)

        # Flatten: take the static feature count of the tensor actually being
        # reshaped, as a plain Python int rather than a Dimension object.
        flat_dim = int(np.prod(x.get_shape().as_list()[1:]))
        flat = tf.reshape(x, (-1, flat_dim))
        logits = tf.layers.dense(flat, 1)
        out = tf.sigmoid(logits)
    return out, logits
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def generator(z, out_channel_dim, is_train=True):
    """Build the generator network.

    A dense layer projects ``z`` to a 7x7x1024 volume, which is then passed
    through five 5x5 transposed convolutions (each preceded by additive
    Gaussian noise and followed by batch normalization, a leaky ReLU and
    dropout) and a final stride-2 transposed convolution with a tanh output.

    Args:
        z: Latent input tensor fed to the first dense layer.
        out_channel_dim: Number of channels produced by the output layer.
        is_train: Controls batch-norm and dropout training mode. Variables
            are reused (``reuse=not is_train``) when this is False.

    Returns:
        The ``tanh``-activated output tensor.

    NOTE(review): with padding='same' the strides below upsample the 7x7
    volume by 2*2*2*2*1 in the loop and by 2 in the output layer; confirm
    that the resulting resolution matches the real images given to the
    discriminator.
    """
    alpha = 0.1  # negative-side slope of the leaky ReLU
    filters = [1024, 512, 256, 128, 64, 32]
    strides = [2, 2, 2, 2, 1]
    with tf.variable_scope('generator', reuse=not is_train):
        # First fully connected layer
        x = tf.layers.dense(z, 7 * 7 * filters[0])
        # Reshape it to start the convolutional stack
        x = tf.reshape(x, (-1, 7, 7, filters[0]))
        bn = tf.layers.batch_normalization(x, training=is_train)
        x = tf.maximum(alpha * bn, bn)

        for size, stride in zip(filters[1:], strides):
            # NOTE(review): this noise is added regardless of is_train, so it
            # is also present when sampling -- confirm that is intended.
            noise = tf.random_normal(shape=tf.shape(x), mean=0.0, stddev=0.2,
                                     dtype=tf.float32)
            x = x + noise
            x = tf.layers.conv2d_transpose(
                x, size, 5, strides=stride, padding='same', activation=None,
                kernel_initializer=tf.contrib.layers.xavier_initializer_conv2d())
            bn = tf.layers.batch_normalization(x, training=is_train)
            relu = tf.maximum(alpha * bn, bn)
            # tf.layers.dropout defaults to training=False, which returns its
            # input unchanged. Tie it to is_train so dropout is applied while
            # training and disabled when sampling.
            x = tf.layers.dropout(relu, 0.3, training=is_train)

        logits = tf.layers.conv2d_transpose(
            x, out_channel_dim, 5, strides=2, padding='same', activation=None,
            kernel_initializer=tf.contrib.layers.xavier_initializer_conv2d())
        out = tf.tanh(logits)
    return out
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def model_loss(input_real, input_z, out_channel_dim):
    """Build the discriminator and generator losses.

    The discriminator targets are randomly smoothed: real targets are
    ``1 - u`` and fake targets are ``0 + u`` with ``u`` drawn uniformly from
    [0, 0.2). The generator loss uses hard targets of 1 on the fake logits.

    Args:
        input_real: Tensor of real images passed to the discriminator.
        input_z: Latent tensor passed to the generator.
        out_channel_dim: Number of output channels for the generator.

    Returns:
        A tuple ``(d_loss, g_loss)``.
    """
    fake_images = generator(input_z, out_channel_dim)
    real_logits = discriminator(input_real)[1]
    fake_logits = discriminator(fake_images, reuse=True)[1]

    # Per-example random offsets used to soften the discriminator targets.
    real_offset = tf.random_uniform(tf.shape(real_logits), minval=0, maxval=0.2)
    fake_offset = tf.random_uniform(tf.shape(fake_logits), minval=0, maxval=0.2)

    cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits

    # Discriminator on real images: targets slightly below 1.
    real_targets = tf.ones_like(real_logits) - real_offset
    d_loss_real = tf.reduce_mean(
        cross_entropy(logits=real_logits, labels=real_targets))

    # Discriminator on generated images: targets slightly above 0.
    fake_targets = tf.zeros_like(fake_logits) + fake_offset
    d_loss_fake = tf.reduce_mean(
        cross_entropy(logits=fake_logits, labels=fake_targets))

    # Generator: wants the discriminator to output 1 on generated images.
    generator_targets = tf.ones_like(fake_logits)
    g_loss = tf.reduce_mean(
        cross_entropy(logits=fake_logits, labels=generator_targets))

    return d_loss_real + d_loss_fake, g_loss
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Outline of the loss-balancing schedule; the actual training steps are
# elided and only described by the comments inside each loop.
# NOTE(review): indentation was lost in this copy, so it is unclear whether
# the second loop is nested in the first or follows it -- confirm.
# Keep stepping the discriminator while its loss is above 2.
while train_loss_d > 2: | |
# Train the discriminator and store its loss in train_loss_d
# Keep stepping the generator while its loss exceeds the discriminator's.
while train_loss_g > train_loss_d: | |
# Train the generator and store its loss in train_loss_g
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment