Skip to content

Instantly share code, notes, and snippets.

@shifdz
Last active September 5, 2021 14:32
Show Gist options
  • Save shifdz/118ba79678712a10d47e087a5dbd4f93 to your computer and use it in GitHub Desktop.
Save shifdz/118ba79678712a10d47e087a5dbd4f93 to your computer and use it in GitHub Desktop.
import tensorflow.keras.backend as K
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_probability as tfp
# Short aliases for the Keras and TensorFlow Probability namespaces used below.
tfk = tf.keras
tfkl = tfk.layers
tfpl = tfp.layers
tfd = tfp.distributions

# Report whether TensorFlow sees a GPU. The model below uses channels_first
# convolutions, so running without one is flagged (but not fatal here).
if tf.test.gpu_device_name() == '/device:GPU:0':
    print('SUCCESS: Found GPU: {}'.format(tf.test.gpu_device_name()))
else:
    print('WARNING: GPU device not found.')
# Load MNIST through TensorFlow Datasets together with its metadata. With
# as_supervised=False each example is a feature dict; `_preprocess` reads its
# 'image' entry.
datasets, datasets_info = tfds.load(
    name='mnist',
    as_supervised=False,
    with_info=True,
)
def _preprocess(sample):
    """Convert one MNIST feature dict into an autoencoding (input, target) pair.

    The 'image' entry is scaled to [0, 1] and stochastically binarized against
    freshly drawn uniform noise. Note the comparison direction: a pixel becomes
    True when its intensity is *below* the noise draw, i.e. with probability
    1 - intensity, so the binary image is an inverted (negative) rendering.

    The boolean result is reshaped to channels-first (1, 28, 28) and returned
    twice: once as the model input and once as the training target.
    """
    pixels = tf.cast(sample['image'], tf.float32) / 255.
    noise = tf.random.uniform(tf.shape(pixels))
    binary = tf.reshape(pixels < noise, [1, 28, 28])
    return binary, binary
# Input pipelines.
#
# Training: shuffle individual examples *before* batching so each epoch sees
# freshly mixed batches. Shuffling after `.batch()` would only permute whole
# 256-example batches whose contents never change. `prefetch` goes last so the
# next batch is prepared while the current one trains. Nothing is cached, so
# the random binarization in `_preprocess` is redrawn every epoch.
train_dataset = (datasets['train']
                 .shuffle(10000)
                 .map(_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
                 .batch(256)
                 .prefetch(tf.data.AUTOTUNE))

# Evaluation: same preprocessing, no shuffling needed.
eval_dataset = (datasets['test']
                .map(_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
                .batch(256)
                .prefetch(tf.data.AUTOTUNE))
# Model hyper-parameters.
input_shape = (1, 28, 28)   # channels-first (C, H, W), as produced by _preprocess
encoded_size = 16           # the final conv layer uses 4 * encoded_size filters
base_depth = 32             # filters in the first conv layers (doubled later)

# Event shape of every mixture component. It must equal the shape of the
# training target, which is the image itself (see `_preprocess`).
# With event_shape = [1], the output distribution has batch shape (batch,) and
# event shape (1,). `log_prob` on a (batch, 1, 28, 28) target then fails to
# broadcast and raises a shape error.
# NOTE(review): this places a continuous Normal density on binary pixels;
# confirm that is the intended likelihood.
event_shape = list(input_shape)
num_components = 5          # number of components in the mixture

# Length of the flat parameter vector the final Dense layer must emit for
# MixtureSameFamily: one logit per component plus each component's
# IndependentNormal parameters (loc and scale per event element).
params_size = tfpl.MixtureSameFamily.params_size(
    num_components,
    component_params_size=tfpl.IndependentNormal.params_size(event_shape))
# Convolutional network mapping a channels-first (1, 28, 28) image to a
# MixtureSameFamily distribution with `num_components` IndependentNormal
# components of shape `event_shape`.
# Spatial size through the conv stack: 28 -> 28 -> 14 -> 14 -> 7, then a
# 7x7 'valid' convolution reduces it to 1x1 before flattening.
# NOTE(review): TensorFlow's CPU Conv2D kernel has historically supported only
# channels_last; confirm this model runs when no GPU is present.
encoder = tfk.Sequential(
    [
        tfkl.InputLayer(input_shape=input_shape),
        # Inputs arrive as booleans from _preprocess; map {0, 1} -> {-0.5, 0.5}.
        tfkl.Lambda(lambda x: tf.cast(x, tf.float32) - 0.5),
    ]
    + [
        tfkl.Conv2D(filters, kernel, strides=stride, padding=pad,
                    data_format="channels_first",
                    activation=tf.nn.leaky_relu)
        for filters, kernel, stride, pad in (
            (base_depth, 5, 1, 'same'),
            (base_depth, 5, 2, 'same'),
            (2 * base_depth, 5, 1, 'same'),
            (2 * base_depth, 5, 2, 'same'),
            (4 * encoded_size, 7, 1, 'valid'),
        )
    ]
    + [
        tfkl.Flatten(),
        # Raw, unconstrained parameters consumed by the distribution layer.
        tfkl.Dense(params_size, activation=None),
        tfpl.MixtureSameFamily(num_components,
                               tfpl.IndependentNormal(event_shape)),
    ]
)
# Training model: the encoder's own graph wrapped as a functional Model.
vae = tfk.Model(inputs=encoder.inputs, outputs=encoder.outputs)


def negloglik(x, rv_x):
    """Return the negative log-likelihood of the targets under the model output.

    Keras calls this with the batch targets ``x`` and the model's output
    distribution ``rv_x``. Targets are cast to float32 first because
    ``_preprocess`` yields boolean images.
    """
    return -rv_x.log_prob(tf.cast(x, dtype='float32'))


vae.compile(loss=negloglik,
            optimizer=tf.optimizers.Adam(learning_rate=1e-3))
_ = vae.fit(train_dataset, validation_data=eval_dataset, epochs=15)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment