Created
December 28, 2017 23:41
-
-
Save planetA/9958c986cea1709cc6e6b8260b5e6b9a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Conv2DWrap:
    """Builds a stack of ``tf.layers`` layers one call at a time.

    The wrapper remembers the output of the most recently added layer
    (``last_layer``) and a running layer counter (``current_layer``) that is
    used to derive unique variable-scope / op names such as ``conv_3``.

    NOTE: batch normalization is created through
    ``tf.layers.batch_normalization``; the code building the train op is
    presumably expected to run the ``UPDATE_OPS`` collection -- confirm
    against the training loop.
    """

    def __init__(self, inp, training, alpha, drop_rate=None):
        """Start a new stack.

        :param inp: Input that feeds the first layer
        :param training: Training flag forwarded to batch normalization
            and dropout
        :param alpha: Negative slope of the leaky ReLU, computed as
            ``max(alpha * x, x)``
        :param drop_rate: Dropout rate; ``None`` disables dropout
        """
        self.last_layer = inp
        self.current_layer = 0
        self.alpha = alpha
        self.drop_rate = drop_rate
        self.training = training

    def __batch_normalization(self, x):
        """Batch-normalize ``x`` using the wrapper's training flag."""
        return tf.layers.batch_normalization(x, training=self.training)

    def _leaky_relu(self, x):
        """Leaky ReLU written as ``max(alpha * x, x)``."""
        return tf.maximum(self.alpha * x, x)

    def _apply_dropout(self, x, name=None):
        """Apply dropout to ``x``, or return it untouched when disabled.

        ``tf.layers.dropout`` defaults to ``training=False``, which turns it
        into an identity op, so the training flag has to be passed
        explicitly. With ``drop_rate=None`` there is no rate to apply and
        the input is returned as is.
        """
        if self.drop_rate is None:
            return x
        return tf.layers.dropout(
            x, rate=self.drop_rate, training=self.training, name=name)

    def dropout(self):
        """Add a dropout layer on top of the last layer and return its output.

        The layer counter is advanced even when dropout is disabled so that
        the names of subsequently added layers do not shift.
        """
        self.current_layer += 1
        x = self._apply_dropout(
            self.last_layer, name='dropout_{}'.format(self.current_layer))
        self.last_layer = x
        return x

    def z_gen(self, z, filters, image_proportions):
        """This is the first layer for generator network

        A dense layer is reshaped to ``(-1, h, w, filters)``, batch
        normalized and passed through the leaky ReLU.

        :param z: Randomized input. Not read here: the layer is built on the
            input handed to the constructor (``last_layer``)
        :param filters: Number of filters in the first layer
        :param image_proportions: ``(h, w)`` proportions of the output image
        :return: Output of the layer
        """
        assert self.current_layer == 0, 'z_gen must be the first layer'
        x = self.last_layer
        h, w = image_proportions
        self.current_layer += 1
        with tf.variable_scope('z_gen_{}'.format(self.current_layer)):
            # First fully connected layer
            x = tf.layers.dense(x, h * w * filters)
            # Reshape it to start the convolutional stack
            x = tf.reshape(x, (-1, h, w, filters))
            x = self.__batch_normalization(x)
            x = self._leaky_relu(x)
        self.last_layer = x
        return x

    def transpose(self, filters, strides):
        """Transpose convolution for the generator

        Applies a 5x5 ``conv2d_transpose`` with ``'same'`` padding, batch
        normalization and the leaky ReLU.

        :param filters: Number of output filters
        :param strides: Strides of the transposed convolution
        :return: Output of the layer
        """
        x = self.last_layer
        self.current_layer += 1
        with tf.variable_scope('conv_transpose_{}'.format(self.current_layer)):
            x = tf.layers.conv2d_transpose(
                x, filters, 5, strides=strides, padding='same')
            x = self.__batch_normalization(x)
            x = self._leaky_relu(x)
        self.last_layer = x
        return x

    def first(self, size_mult, strides, kernel=3, padding='same'):
        """First layer of the discriminator network

        Convolution followed by the leaky ReLU and dropout; no batch
        normalization is applied in this layer.

        :param size_mult: Number of convolution filters
        :param strides: Convolution strides
        :param kernel: Convolution kernel size
        :param padding: Convolution padding mode
        :return: Output of the layer
        """
        assert self.current_layer == 0, 'first must be the first layer'
        x = self.last_layer
        self.current_layer += 1
        with tf.variable_scope('conv_{}'.format(self.current_layer)):
            x = tf.layers.conv2d(
                x, size_mult, kernel, strides=strides, padding=padding)
            relu = self._leaky_relu(x)
            relu = self._apply_dropout(relu)
        self.last_layer = relu
        return relu

    def pooled(self, size_mult, strides, kernel=3, padding='same'):
        """Convolution, leaky ReLU and max pooling.

        ``strides`` is used for the convolution strides and also as both the
        pool size and the pool strides.

        :param size_mult: Number of convolution filters
        :param strides: Convolution strides, pool size and pool strides
        :param kernel: Convolution kernel size
        :param padding: Padding mode for the convolution and the pooling
        :return: Output of the layer
        """
        x = self.last_layer
        assert x is not None, 'no previous layer to build on'
        self.current_layer += 1
        with tf.variable_scope('conv_{}'.format(self.current_layer)):
            x = tf.layers.conv2d(
                x, size_mult, kernel, strides=strides, padding=padding)
            relu = self._leaky_relu(x)
            pool_size = (strides, strides)
            pool = tf.layers.max_pooling2d(
                relu, pool_size, (strides, strides), padding=padding)
        self.last_layer = pool
        return pool

    def batched(self, size_mult, strides, kernel=3, padding='same'):
        """Convolution, batch normalization and leaky ReLU.

        :param size_mult: Number of convolution filters
        :param strides: Convolution strides
        :param kernel: Convolution kernel size
        :param padding: Convolution padding mode
        :return: Output of the layer
        """
        x = self.last_layer
        assert x is not None, 'no previous layer to build on'
        self.current_layer += 1
        with tf.variable_scope('conv_{}'.format(self.current_layer)):
            x = tf.layers.conv2d(
                x, size_mult, kernel, strides=strides, padding=padding)
            x = self.__batch_normalization(x)
            relu = self._leaky_relu(x)
        self.last_layer = relu
        return relu
def generator(z, output_dim, training, reuse=False, alpha=0.2, size_mult=128):
    """Build the generator network inside the ``'generator'`` variable scope.

    The network is a dense ``z_gen`` layer reshaped to a 4x3 map with
    ``size_mult * 16`` filters, three transposed-convolution stages from
    ``Conv2DWrap.transpose``, and a final 5x5 stride-2 transposed
    convolution squashed by ``tanh``.

    :param z: Randomized input of the generator
    :param output_dim: Number of filters of the output layer
    :param training: Training flag handed to ``Conv2DWrap``
    :param reuse: Forwarded to ``tf.variable_scope``
    :param alpha: Leaky ReLU slope handed to ``Conv2DWrap``
    :param size_mult: Base multiplier for the filter counts
    :return: ``tanh`` of the output layer's logits
    """
    # (filter multiplier, strides) of each transposed-convolution stage, in
    # order. A (2, 2) stage between the last two entries is currently disabled.
    upsampling_stages = ((8, 5), (4, 5), (1, 2))

    with tf.variable_scope('generator', reuse=reuse):
        stack = Conv2DWrap(z, training, alpha=alpha)
        stack.z_gen(z, size_mult * 16, (4, 3))

        for multiplier, stride in upsampling_stages:
            features = stack.transpose(size_mult * multiplier, strides=stride)

        # Output layer
        logits = tf.layers.conv2d_transpose(
            features, output_dim, 5, strides=2, padding='same')
        return tf.tanh(logits)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment