Skip to content

Instantly share code, notes, and snippets.

@BlueskyFR
Created October 24, 2021 00:07
Show Gist options
  • Save BlueskyFR/94337cf7fe8568ae43969b47d6884432 to your computer and use it in GitHub Desktop.
Save BlueskyFR/94337cf7fe8568ae43969b47d6884432 to your computer and use it in GitHub Desktop.
# To add a new cell, type '# %%'
# To add a new markdown cell, type '# %% [markdown]'
# %%
import time
global_start_time = time.time()  # wall-clock start of the run; the last cell prints the elapsed total from it
# %%
import tensorflow as tf
print(f"✨ Using TensorFlow {tf.__version__}!")

# Let TensorFlow allocate GPU memory on demand instead of mapping nearly all of it up front.
# tf.config.list_physical_devices is the stable (non-experimental) device-listing API.
# Memory growth can only be configured before the GPUs are initialized; re-running this
# cell after TensorFlow has already used a GPU raises RuntimeError, which is reported
# here instead of aborting the whole run.
try:
    for device in tf.config.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(device, True)
except RuntimeError as err:
    print(f"Could not enable GPU memory growth: {err}")
import glob
import math
from pathlib import Path

import imageio
import matplotlib.pyplot as plt
from IPython import display
from tensorflow.keras.layers import Resizing, Rescaling, Dense, BatchNormalization, LeakyReLU, Conv2DTranspose, Reshape, Conv2D, Dropout, Flatten
def plot(img):
    """Show an image whose pixel values lie in [-1, 1] with matplotlib.

    The values are shifted and halved into [0, 1], the range ``imshow`` expects
    for floating-point RGB data, before being drawn on the current axes.
    """
    in_unit_range = (img + 1) / 2
    plt.imshow(in_unit_range)
# %% [markdown]
# # Load the data
# %%
DATA_DIR = Path("./cats/")
IMG_SIZE = (128, 128)

# List images.
# NOTE(review): TensorFlow's file glob is not documented to treat "**" as a recursive
# wildcard -- confirm this pattern matches every image under DATA_DIR.
dataset = tf.data.Dataset.list_files((DATA_DIR / "**/*.jpg").as_posix())

# Load images. channels=3 forces RGB output: with the default (0) a grayscale JPEG
# decodes to a single channel, which cannot be batched together with RGB images and
# does not match the 3-channel input the discriminator is built for.
dataset = dataset.map(
    lambda file: tf.io.decode_jpeg(tf.io.read_file(file), channels=3),
    num_parallel_calls=tf.data.AUTOTUNE,
)

# Preprocessing: resize to IMG_SIZE, then normalize from [0, 255] to [-1, 1]
preprocess = tf.keras.Sequential([
    Resizing(*IMG_SIZE),
    Rescaling(scale=1. / 127.5, offset=-1),
])
dataset = dataset.map(lambda img: preprocess(img), num_parallel_calls=tf.data.AUTOTUNE)
print(f"Loaded dataset of {len(dataset)} cats!")

# Show one sample; plot() maps [-1, 1] back to [0, 1] for imshow.
for i in dataset.take(1):
    plot(i)
# %%
# (train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
# train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype("float32")
# train_images = (train_images - 127.5) / 127.5 # Normalize the images to [-1, 1]
# %% [markdown]
# ## Batch, cache and shuffle the dataset
# %%
# A shuffle buffer as large as the dataset gives a full (not windowed) shuffle.
BUFFER_SIZE = len(dataset)
BATCH_SIZE = 256
# Batch and shuffle the data
# train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
# cache() keeps the decoded, preprocessed images in memory after the first pass;
# shuffle() reshuffles on every pass by default; prefetch() prepares upcoming batches
# while the current one is being trained on, so training does not wait on input.
dataset = dataset.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
print(f"Each epoch will contain {len(dataset)} batches!")
# %% [markdown]
# # Create the models
# ## The Generator
# %%
# Generator: maps a 100-dim noise vector to a 128x128 RGB image with values in
# [-1, 1] (tanh output). Each (filters, strides) pair below becomes one
# Conv2DTranspose ("deconv" / upsampling) + BatchNormalization + LeakyReLU stage;
# strides=2 doubles the spatial size, strides=1 keeps it.
generator = tf.keras.Sequential(
    [
        Dense(8 * 8 * 256, use_bias=False, input_shape=(100,)),
        BatchNormalization(),
        LeakyReLU(),
        Reshape((8, 8, 256)),
    ]
    + [
        stage_layer
        for filters, strides in (
            (1024, 1),  # -> (None, 8, 8, 1024); None is the batch size
            (512, 2),   # -> (None, 16, 16, 512)
            (256, 1),   # -> (None, 16, 16, 256)
            (128, 2),   # -> (None, 32, 32, 128)
            (64, 2),    # -> (None, 64, 64, 64)
        )
        for stage_layer in (
            Conv2DTranspose(filters=filters, kernel_size=5, strides=strides, padding="same", use_bias=False),
            BatchNormalization(),
            LeakyReLU(),
        )
    ]
    + [
        # -> (None, 128, 128, 3)
        Conv2DTranspose(filters=3, kernel_size=5, strides=2, padding="same", use_bias=False, activation="tanh"),
    ]
)
# %% [markdown]
# ### Test the Generator
# %%
# Draw one latent vector and push it through the (still untrained) generator.
noise = tf.random.normal((1, 100))
print(f"Noise shape: {noise.shape}")

# training=False runs the layers in inference mode, e.g. BatchNormalization
# normalizes with its moving statistics rather than this batch's statistics.
generated_image = generator(noise, training=False)
print(generated_image.shape)

first_image = generated_image[0]  # the batch holds a single image
plot(first_image)
# %% [markdown]
# ## The Discriminator
#
# Classifies the images as real or fake. Positive is real, negative is fake.
# %%
# Discriminator: classifies an image as real or fake by emitting one raw logit
# (Dense(1), no activation) -- positive means real, negative means fake.
# The output shapes below assume IMG_SIZE == (128, 128).
discriminator = tf.keras.Sequential([
    Conv2D(filters=64, kernel_size=5, strides=2, padding="same", input_shape=(*IMG_SIZE, 3)),
    # Output shape: (None, 64, 64, 64)
    LeakyReLU(),
    Dropout(0.3),
    Conv2D(filters=128, kernel_size=5, strides=2, padding="same"),
    # Output shape: (None, 32, 32, 128)
    LeakyReLU(),
    Dropout(0.3),
    Conv2D(filters=256, kernel_size=5, strides=2, padding="same"),
    # Output shape: (None, 16, 16, 256)
    LeakyReLU(),
    Dropout(0.3),
    Conv2D(filters=512, kernel_size=5, strides=1, padding="same"),
    # Output shape: (None, 16, 16, 512)
    LeakyReLU(),
    Dropout(0.3),
    # strides=2 downsamples 16x16 -> 8x8, mirroring the generator's first 8x8x1024
    # stage. With strides=1 the output stayed 16x16x1024, contradicting the
    # documented shape and quadrupling the size of the Flatten/Dense(1) input.
    Conv2D(filters=1024, kernel_size=5, strides=2, padding="same"),
    # Output shape: (None, 8, 8, 1024)
    LeakyReLU(),
    Dropout(0.3),
    Flatten(),
    Dense(1)
])
# %% [markdown]
# ### Test the Discriminator on the previously generated image
# %%
# Score the generated sample with the untrained discriminator
# (training=False: Dropout layers are disabled, i.e. inference mode).
decision = discriminator(generated_image, training=False)
print(decision)  # a single raw logit: > 0 leans "real", < 0 leans "fake"
# %% [markdown]
# Because the network is untrained (weights are randomly initialized and biases start at zero), the output is close to 0.
#
# # Loss and optimizers
# %%
# Binary cross-entropy loss object (callable) shared by both model losses below.
# from_logits=True because the discriminator ends in Dense(1) with no activation,
# so its outputs are raw logits rather than probabilities.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
# %% [markdown]
# ## Discriminator loss
#
# At each training step, the discriminator receives a batch of real images and a batch of fake (generated) images.
# The output for a batch of real images should be an array of 1s, and an array of 0s for a fake one.
# %%
def discriminator_loss(real_output, fake_output):
    """Return the discriminator's loss for one training step.

    The discriminator should label real images 1 and generated images 0; the loss
    is the sum of the cross-entropy on each of those two groups.

    Args:
        real_output: Discriminator logits for the batch of real images.
        fake_output: Discriminator logits for the batch of generated images.
    """
    loss_on_real = cross_entropy(tf.ones_like(real_output), real_output)
    loss_on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total = loss_on_real + loss_on_fake
    return total
# %% [markdown]
# ## Generator loss
#
# The generator loss quantifies how well the generator was able to trick the discriminator. If the generator is performing well, the discriminator will classify the fake (i.e. generated) images as real, i.e. output an array of 1s for them.
# %%
def generator_loss(fake_output):
    """Return the generator's loss for one training step.

    The generator "wins" when the discriminator labels its images as real (1),
    so its loss is the cross-entropy of the fake logits against all-ones targets.

    Args:
        fake_output: Discriminator logits for the batch of generated images.
    """
    target_labels = tf.ones_like(fake_output)
    return cross_entropy(target_labels, fake_output)
# %% [markdown]
# ## Optimizers
# %%
# One Adam optimizer per network (learning rate 1e-4), since each model is
# updated from its own loss and gradients.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
# %% [markdown]
# # Define the training loop
# %%
EPOCHS = 5  # previously 1000; presumably lowered for quick runs
noise_dim = 100  # length of each latent vector fed to the generator
num_examples_to_generate = 16

# Fixed latent vectors, reused at every snapshot so the saved images (and the
# GIF built from them) show the same samples evolving over training.
demo_gif_seed = tf.random.normal([num_examples_to_generate, noise_dim])
# tf.function traces this Python function into a TensorFlow graph, so each training
# step runs as graph code rather than op-by-op in eager mode.
@tf.function
def train_step(real_images):
    """Run one simultaneous update of the generator and the discriminator.

    Args:
        real_images: One batch of real images from the training dataset.

    Returns:
        ``(gen_loss, disc_loss)``: the scalar losses of this step, for optional logging.
    """
    # One noise vector per real image. The final batch of an epoch is smaller than
    # BATCH_SIZE whenever the dataset size is not a multiple of it, so take the size
    # from the batch itself instead of the constant.
    batch_size = tf.shape(real_images)[0]
    noise = tf.random.normal((batch_size, noise_dim))

    # Two tapes: each network's gradients are taken w.r.t. its own loss only.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        fake_images = generator(noise, training=True)

        real_output = discriminator(real_images, training=True)
        fake_output = discriminator(fake_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    generator_gradients = gen_tape.gradient(gen_loss, generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))

    return gen_loss, disc_loss
# %%
def train(dataset, epochs):
    """Train the GAN for ``epochs`` full passes over ``dataset``.

    Every batch is handed to ``train_step``; the wall-clock duration of each
    epoch is printed when the epoch finishes.

    Args:
        dataset: Iterable of real-image batches.
        epochs: Number of passes over ``dataset``.
    """
    for epoch in range(epochs):
        start = time.time()
        for batch in dataset:
            train_step(batch)

        # Per-epoch snapshot hook (currently disabled): render the fixed demo
        # seed so a GIF of the training progress can be assembled later.
        # display.clear_output(wait=True)
        # generate_and_save_images(generator, epoch + 1, demo_gif_seed)

        print(f"Time for epoch {epoch + 1} is {time.time() - start}")

    # Final snapshot hook (currently disabled).
    # display.clear_output(wait=True)
    # generate_and_save_images(generator, epochs, demo_gif_seed)
# %%
def generate_and_save_images(model, epoch, test_input):
    """Generate images from ``test_input``, display them in a grid and save the figure.

    Args:
        model: The generator; it is called in inference mode (``training=False``).
        epoch: Epoch number embedded in the file name ``image_at_epoch_XXX.png``.
        test_input: Batch of latent vectors; one image is drawn per vector.
    """
    predictions = model(test_input, training=False)
    count = predictions.shape[0]

    # Smallest near-square grid that holds every image (4x4 for the usual 16).
    # A fixed 4x4 grid made plt.subplot raise for more than 16 inputs.
    cols = max(1, math.ceil(math.sqrt(count)))
    rows = max(1, math.ceil(count / cols))

    fig = plt.figure(figsize=(4, 4))
    for index in range(count):
        plt.subplot(rows, cols, index + 1)
        plot(predictions[index])  # plot() maps [-1, 1] to [0, 1] for imshow
        plt.axis("off")

    plt.savefig(f"image_at_epoch_{epoch:03d}.png")
    plt.show()
    # Release the figure: when called once per epoch outside an inline notebook
    # backend, unclosed figures would otherwise accumulate in memory.
    plt.close(fig)
# %% [markdown]
# # Train the model
# %%
train(dataset=dataset, epochs=EPOCHS)  # run the full training loop
# %% [markdown]
# # Create a GIF
# %%
# gif_file = "gan.gif"
# with imageio.get_writer(gif_file, mode='I') as writer:
# for filename in sorted(glob.glob("image*.png")):
# image = imageio.imread(filename)
# writer.append_data(image)
# image = imageio.imread(filename)  # re-append the last frame; reading gif_file here would read the GIF that is still being written
# writer.append_data(image)
# Report the wall-clock time elapsed since global_start_time was recorded at the top of the script.
print(f"Total run time (in seconds): {time.time() - global_start_time}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment