"""TensorFlow 2.0 implementation of a convolutional Autoencoder.""" | |
import tensorflow as tf | |
from datetime import datetime | |
tf.random.set_seed(1) | |
batch_size = 128 | |
epochs = 10 | |
learning_rate = 0.001 | |
(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data() | |
x_train = x_train.astype('float32') / 255. | |
x_test = x_test.astype('float32') / 255. | |
x_train = tf.reshape(x_train, (len(x_train), 28, 28, 1)) | |
x_test = tf.reshape(x_test, (len(x_test), 28, 28, 1)) | |
class Encoder(tf.keras.layers.Layer):
    def __init__(self):
        super(Encoder, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(
            16, (3, 3), activation='relu', padding='same')
        self.maxp1 = tf.keras.layers.MaxPooling2D((2, 2), padding='same')
        self.conv2 = tf.keras.layers.Conv2D(
            8, (3, 3), activation='relu', padding='same')
        self.maxp2 = tf.keras.layers.MaxPooling2D((2, 2), padding='same')
        self.conv3 = tf.keras.layers.Conv2D(
            8, (3, 3), activation='relu', padding='same')
        self.encoded = tf.keras.layers.MaxPooling2D((2, 2), padding='same')

    @tf.function
    def call(self, input_features):
        # Three conv + pool stages halve the spatial resolution each time:
        # 28x28x1 -> 14x14x16 -> 7x7x8 -> 4x4x8.
        x = self.conv1(input_features)
        x = self.maxp1(x)
        x = self.conv2(x)
        x = self.maxp2(x)
        x = self.conv3(x)
        return self.encoded(x)
class Decoder(tf.keras.layers.Layer):
    def __init__(self):
        super(Decoder, self).__init__()
        self.conv4 = tf.keras.layers.Conv2D(
            8, (3, 3), activation='relu', padding='same')
        self.upsample1 = tf.keras.layers.UpSampling2D((2, 2))
        self.conv5 = tf.keras.layers.Conv2D(
            8, (3, 3), activation='relu', padding='same')
        self.upsample2 = tf.keras.layers.UpSampling2D((2, 2))
        # Valid padding is deliberate here: it trims 16x16 down to 14x14 so
        # that the final upsampling restores the original 28x28 resolution.
        self.conv6 = tf.keras.layers.Conv2D(16, (3, 3), activation='relu')
        self.upsample3 = tf.keras.layers.UpSampling2D((2, 2))
        self.decoded = tf.keras.layers.Conv2D(
            1, (3, 3), activation='sigmoid', padding='same')

    @tf.function
    def call(self, encoded_features):
        # Mirror of the encoder:
        # 4x4x8 -> 8x8x8 -> 16x16x8 -> 14x14x16 -> 28x28x16 -> 28x28x1.
        x = self.conv4(encoded_features)
        x = self.upsample1(x)
        x = self.conv5(x)
        x = self.upsample2(x)
        x = self.conv6(x)
        x = self.upsample3(x)
        return self.decoded(x)
class Autoencoder(tf.keras.Model):
    def __init__(self):
        super(Autoencoder, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    @tf.function
    def call(self, input_features):
        encoded = self.encoder(input_features)
        reconstructed = self.decoder(encoded)
        return reconstructed
autoencoder = Autoencoder()
opt = tf.optimizers.Adam(learning_rate=learning_rate)
# opt = tf.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)
# opt = tf.optimizers.Adadelta(learning_rate=learning_rate, rho=0.95,
#                              epsilon=1e-08)
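# Optional sanity check (not in the original): a single forward pass should
# round-trip the input shape.
print(autoencoder(x_train[:1]).shape)  # expected: (1, 28, 28, 1)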
def loss(model, original):
    # Mean squared error between the input batch and its reconstruction.
    reconstruction_error = tf.reduce_mean(
        tf.square(tf.subtract(model(original), original)))
    return reconstruction_error
def train(loss, model, opt, original):
    # Record the forward pass on the tape, then compute gradients outside
    # the tape context and apply one optimizer update.
    with tf.GradientTape() as tape:
        loss_value = loss(model, original)
    gradients = tape.gradient(loss_value, model.trainable_variables)
    opt.apply_gradients(zip(gradients, model.trainable_variables))
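# Optional (not in the original): wrapping the update step in tf.function
# traces it into a graph, which typically speeds up the eager loop below:
# train = tf.function(train)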
logdir = "logs/train_data/" + datetime.now().strftime("%Y%m%d-%H%M%S")
writer = tf.summary.create_file_writer(logdir)

with writer.as_default():
    with tf.summary.record_if(True):
        # Use a global step across epochs so summaries from later epochs
        # do not land on the same step values as earlier ones.
        step = 0
        for epoch in range(epochs):
            print("Epoch: ", epoch)
            for x in range(0, len(x_train), batch_size):
                step += 1
                batch_features = x_train[x: x + batch_size]
                train(loss, autoencoder, opt, batch_features)
                loss_values = loss(autoencoder, batch_features)
                original = tf.reshape(
                    batch_features, (batch_features.shape[0], 28, 28, 1))
                reconstructed = tf.reshape(
                    autoencoder(batch_features),
                    (batch_features.shape[0], 28, 28, 1))
                tf.summary.scalar('loss', loss_values, step=step)
                tf.summary.image('original', original,
                                 max_outputs=10, step=step)
                tf.summary.image('reconstructed', reconstructed,
                                 max_outputs=10, step=step)
                if step % 10 == 0:
                    print("Epoch: {}, Step: {}, Loss: {}".format(
                        epoch, step, loss_values))
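
# A minimal evaluation sketch (not in the original): x_test is loaded above
# but never used, so report the reconstruction loss on the held-out set.
print("Test loss: {}".format(loss(autoencoder, x_test)))

# Inspect the logged scalars and images with:
#   tensorboard --logdir logs/train_data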