Skip to content

Instantly share code, notes, and snippets.

@eoehri
Last active December 21, 2019 18:51
Show Gist options
  • Star 0 (you must be signed in to star a gist)
  • Fork 0 (you must be signed in to fork a gist)
  • Save eoehri/bead4e397e9fd464454a6748226b5965 to your computer and use it in GitHub Desktop.
Save eoehri/bead4e397e9fd464454a6748226b5965 to your computer and use it in GitHub Desktop.
"""TensorFlow 2.0 implementation of a convolutional Autoencoder."""
import tensorflow as tf
from datetime import datetime
tf.random.set_seed(1)

# Training hyper-parameters.
batch_size = 128
epochs = 10
learning_rate = 0.001


def _as_image_batch(images):
    """Scale pixel values by 1/255 and reshape to (count, 28, 28, 1).

    The trailing axis of size 1 is the single (grayscale) channel expected by
    the Conv2D layers below. The result is a tf.Tensor.
    """
    scaled = images.astype('float32') / 255.
    return tf.reshape(scaled, (len(scaled), 28, 28, 1))


# Labels are discarded: the autoencoder's target is its own input.
(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
x_train = _as_image_batch(x_train)
x_test = _as_image_batch(x_test)
class Encoder(tf.keras.layers.Layer):
    """Convolutional encoder made of three conv + max-pool stages.

    Each stage is a 3x3 ReLU convolution ('same' padding) followed by a 2x2
    max-pool ('same' padding), so each stage halves the spatial size, rounding
    up. For the 28x28x1 images prepared at module level the spatial size goes
    28 -> 14 -> 7 -> 4, with 16, 8 and 8 filters. The code is therefore 4x4x8.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are kept unchanged: Keras tracks sub-layers through
        # them, so they determine checkpoint keys.
        self.conv1 = tf.keras.layers.Conv2D(
            16, (3, 3), activation='relu', padding='same')
        self.maxp1 = tf.keras.layers.MaxPooling2D((2, 2), padding='same')
        self.conv2 = tf.keras.layers.Conv2D(
            8, (3, 3), activation='relu', padding='same')
        self.maxp2 = tf.keras.layers.MaxPooling2D((2, 2), padding='same')
        self.conv3 = tf.keras.layers.Conv2D(
            8, (3, 3), activation='relu', padding='same')
        self.encoded = tf.keras.layers.MaxPooling2D((2, 2), padding='same')

    @tf.function
    def call(self, input_features):
        """Apply every stage in order and return the compressed feature map."""
        pipeline = (self.conv1, self.maxp1,
                    self.conv2, self.maxp2,
                    self.conv3, self.encoded)
        features = input_features
        for stage in pipeline:
            features = stage(features)
        return features
class Decoder(tf.keras.layers.Layer):
    """Convolutional decoder that mirrors `Encoder` back to image size.

    The decoder has three conv + 2x upsample stages, then a 1-filter sigmoid
    convolution that produces the reconstructed image.

    Assuming a 4x4 input code (from 28x28 images), the spatial sizes are
    4 -> 8 -> 16 -> 14 (conv6) -> 28. conv6 deliberately uses the default
    'valid' padding: the 3x3 kernel trims 16 down to 14, so the last
    upsample lands exactly on 28 rather than 32.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are kept unchanged: Keras tracks sub-layers through
        # them, so they determine checkpoint keys.
        self.conv4 = tf.keras.layers.Conv2D(
            8, (3, 3), activation='relu', padding='same')
        self.upsample1 = tf.keras.layers.UpSampling2D((2, 2))
        self.conv5 = tf.keras.layers.Conv2D(
            8, (3, 3), activation='relu', padding='same')
        self.upsample2 = tf.keras.layers.UpSampling2D((2, 2))
        # No padding='same' here on purpose; see the class docstring.
        self.conv6 = tf.keras.layers.Conv2D(16, (3, 3), activation='relu')
        self.upsample3 = tf.keras.layers.UpSampling2D((2, 2))
        self.decoded = tf.keras.layers.Conv2D(
            1, (3, 3), activation='sigmoid', padding='same')

    @tf.function
    def call(self, encoded_features):
        """Apply every stage in order; the sigmoid keeps outputs in (0, 1)."""
        pipeline = (self.conv4, self.upsample1,
                    self.conv5, self.upsample2,
                    self.conv6, self.upsample3,
                    self.decoded)
        features = encoded_features
        for stage in pipeline:
            features = stage(features)
        return features
class Autoencoder(tf.keras.Model):
    """Model that encodes its input with `Encoder` and decodes it with `Decoder`."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    @tf.function
    def call(self, input_features):
        """Return the reconstruction of `input_features`."""
        return self.decoder(self.encoder(input_features))
autoencoder = Autoencoder()
opt = tf.optimizers.Adam(learning_rate=learning_rate)
# Alternative optimizers to experiment with. Replace the line above with one
# of these. Both use the TF2 keyword `learning_rate` and refer only to names
# defined in this script.
# opt = tf.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)
# opt = tf.optimizers.Adadelta(learning_rate=learning_rate, rho=0.95,
#                              epsilon=1e-08)
def loss(model, original):
    """Return the mean squared reconstruction error of `model` on `original`.

    The model is run on `original`. The squared element-wise difference
    between output and input is averaged over all elements, giving a scalar
    tensor.
    """
    reconstruction = model(original)
    squared_error = tf.square(reconstruction - original)
    return tf.reduce_mean(squared_error)
def train(loss, model, opt, original):
    """Run one gradient-descent step of `model` on the batch `original`.

    Args:
        loss: callable ``loss(model, original)`` returning a scalar tensor.
        model: model whose ``trainable_variables`` are updated.
        opt: optimizer used to apply the gradients.
        original: input batch, which is also the reconstruction target.

    Returns:
        The loss on ``original`` computed *before* the update is applied.
        Callers that ignore the return value are unaffected.
    """
    # Record only the forward pass. The gradient is taken after the tape
    # context closes, which is the pattern in TensorFlow's GradientTape guide.
    with tf.GradientTape() as tape:
        reconstruction_error = loss(model, original)
    gradients = tape.gradient(reconstruction_error, model.trainable_variables)
    opt.apply_gradients(zip(gradients, model.trainable_variables))
    return reconstruction_error
logdir = "logs/train_data/" + datetime.now().strftime("%Y%m%d-%H%M%S")
writer = tf.summary.create_file_writer(logdir)

# Image summaries (10 originals + 10 reconstructions, plus an extra forward
# pass) are written only every `image_summary_every` global steps, to keep the
# event file small. Set this to 1 to log images for every batch.
image_summary_every = 10

# `global_step` keeps counting across epochs, so every summary gets a unique
# TensorBoard step. A per-epoch counter would make each epoch write to the
# same steps as the previous one. `step` (per epoch) is still used for the
# console progress line.
# NOTE(review): batches are taken in the same order every epoch (no shuffling).
global_step = 0
with writer.as_default(), tf.summary.record_if(True):
    for epoch in range(epochs):
        print("Epoch: ", epoch)
        step = 0
        for start in range(0, len(x_train), batch_size):
            step += 1
            global_step += 1
            batch_features = x_train[start: start + batch_size]
            train(loss, autoencoder, opt, batch_features)
            # Loss on the same batch, measured after the update.
            loss_values = loss(autoencoder, batch_features)
            tf.summary.scalar('loss', loss_values, step=global_step)

            if global_step % image_summary_every == 0:
                batch_len = batch_features.shape[0]
                original = tf.reshape(batch_features, (batch_len, 28, 28, 1))
                reconstructed = tf.reshape(
                    autoencoder(batch_features), (batch_len, 28, 28, 1))
                tf.summary.image('original', original,
                                 max_outputs=10, step=global_step)
                tf.summary.image('reconstructed', reconstructed,
                                 max_outputs=10, step=global_step)

            if step % 10 == 0:
                print("Epoch: {}, Step: {}, Loss: {:.6f}".format(
                    epoch, step, float(loss_values)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment