Skip to content

Instantly share code, notes, and snippets.

@ground0state
Created August 22, 2019 12:49
Show Gist options
  • Save ground0state/402e1e33c453d3649bec05d407d90d15 to your computer and use it in GitHub Desktop.
Save ground0state/402e1e33c453d3649bec05d407d90d15 to your computer and use it in GitHub Desktop.
import os
import glob
import math
import random
import numpy as np
import matplotlib.pyplot as plt
import cv2
from tensorflow.python import keras
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.models import Model, Sequential
from tensorflow.python.keras.layers import *
from tensorflow.python.keras.preprocessing.image import load_img, img_to_array, array_to_img, ImageDataGenerator
def split_data_list(items, val_ratio=0.1, test_ratio=0.1):
    """Split *items* positionally into validation, test and training lists.

    The first ``floor(len(items) * val_ratio)`` items go to validation, the
    next ``floor(len(items) * test_ratio)`` to test, and every remaining item
    to training, so nothing is dropped or duplicated.

    Args:
        items: Sliceable sequence to split (here: image file paths).
        val_ratio: Fraction of items reserved for validation.
        test_ratio: Fraction of items reserved for testing.

    Returns:
        Tuple ``(val_items, test_items, train_items)`` of lists.
    """
    n_val = math.floor(len(items) * val_ratio)
    n_test = math.floor(len(items) * test_ratio)
    return (
        list(items[:n_val]),
        list(items[n_val:n_val + n_test]),
        list(items[n_val + n_test:]),
    )


data_path = ""
val_ratio = 0.1
test_ratio = 0.1
split_seed = 0
# os.path.join works whether or not data_path ends with a separator; plain
# string concatenation silently globbed the wrong directory without one.
# glob() returns paths in arbitrary, filesystem-dependent order, so sort them
# and then shuffle with a fixed seed: the split is reproducible across runs
# and machines, and not tied to how the files happen to be named.
data_lists = sorted(glob.glob(os.path.join(data_path, '*.jpg')))
random.Random(split_seed).shuffle(data_lists)
val_n_sample = math.floor(len(data_lists) * val_ratio)
test_n_sample = math.floor(len(data_lists) * test_ratio)
train_n_sample = len(data_lists) - val_n_sample - test_n_sample
val_lists, test_lists, train_lists = split_data_list(data_lists, val_ratio, test_ratio)
def rgb2lab(rgb):
    """Convert an 8-bit RGB image to OpenCV's 8-bit Lab encoding.

    Args:
        rgb: ``np.uint8`` image array in RGB channel order.

    Returns:
        ``np.uint8`` array with the L, a, b channels, from
        ``cv2.cvtColor(..., cv2.COLOR_RGB2Lab)``.

    Raises:
        ValueError: if *rgb* is not ``uint8``.
    """
    # Explicit raise rather than ``assert``: asserts are stripped under -O,
    # which would let non-uint8 input through to cvtColor unchecked.
    if rgb.dtype != np.uint8:
        raise ValueError(f"rgb2lab expects a uint8 array, got dtype {rgb.dtype}")
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2Lab)
def lab2rgb(lab):
    """Convert an 8-bit OpenCV Lab image back to 8-bit RGB.

    Args:
        lab: ``np.uint8`` array with L, a, b channels, in the encoding
            produced by ``cv2.COLOR_RGB2Lab``.

    Returns:
        ``np.uint8`` RGB image from ``cv2.cvtColor(..., cv2.COLOR_Lab2RGB)``.

    Raises:
        ValueError: if *lab* is not ``uint8``.
    """
    # Explicit raise rather than ``assert``: asserts are stripped under -O.
    if lab.dtype != np.uint8:
        raise ValueError(f"lab2rgb expects a uint8 array, got dtype {lab.dtype}")
    return cv2.cvtColor(lab, cv2.COLOR_Lab2RGB)
img_size = 224


def get_lab_from_data_list(data_list):
    """Load image files, resize them and convert them to 8-bit Lab.

    Each file is loaded with ``load_img`` at ``(img_size, img_size)``, cast to
    ``uint8`` and converted with ``rgb2lab``.

    Args:
        data_list: Iterable of image file paths.

    Returns:
        ``np.uint8`` array stacking one Lab image per path along axis 0,
        presumably shaped ``(len(data_list), img_size, img_size, 3)`` given
        ``load_img``'s default RGB mode. An empty *data_list* returns an empty
        ``(0, img_size, img_size, 3)`` array instead of raising.
    """
    x_lab = [
        rgb2lab(img_to_array(load_img(f, target_size=(img_size, img_size))).astype(np.uint8))
        for f in data_list
    ]
    if not x_lab:
        # np.stack([]) raises "need at least one array to stack".
        return np.empty((0, img_size, img_size, 3), dtype=np.uint8)
    return np.stack(x_lab)
# Convolutional autoencoder: the single input channel is L, and the two output
# channels are regressed against the a/b channels (see the generator's split).
# Three stride-2 convolutions downsample by 8 and three stride-2 transposed
# convolutions upsample back, so img_size must be divisible by 8 for the
# output to match the target size (224 -> 28 -> 224).
inputs = Input(shape=(img_size, img_size, 1))
x = Conv2D(filters=32, kernel_size=(3, 3), strides=(1, 1), activation='relu', padding='same')(inputs)
x = Conv2D(filters=64, kernel_size=(3, 3), strides=(2, 2), activation='relu', padding='same')(x)
x = Conv2D(filters=128, kernel_size=(3, 3), strides=(2, 2), activation='relu', padding='same')(x)
x = Conv2D(filters=256, kernel_size=(3, 3), strides=(2, 2), activation='relu', padding='same')(x)
x = Conv2DTranspose(filters=128, kernel_size=(3, 3), strides=(2, 2), activation='relu', padding='same')(x)
x = Conv2DTranspose(filters=64, kernel_size=(3, 3), strides=(2, 2), activation='relu', padding='same')(x)
x = Conv2DTranspose(filters=32, kernel_size=(3, 3), strides=(2, 2), activation='relu', padding='same')(x)
# ReLU keeps predictions non-negative, like the uint8 a/b targets; the upper
# bound of 255 is not enforced here.
outputs = Conv2DTranspose(filters=2, kernel_size=(1, 1), strides=(1, 1), activation='relu', padding='same')(x)
autoencoder = Model(inputs, outputs)
autoencoder.compile(optimizer='adam', loss='mse')
# summary() prints the table itself and returns None; wrapping it in print()
# emitted a stray "None" line.
autoencoder.summary()
def generator_with_preprocessing(data_list, batch_size, shuffle=False):
    """Yield ``(L, ab)`` batches from image paths, cycling forever.

    Each pass walks the paths in chunks of *batch_size*. The last chunk may
    be smaller. Each chunk is loaded with ``get_lab_from_data_list`` and split
    into channel 0 (model input) and channels 1-2 (model target).

    Args:
        data_list: Sequence of image file paths. It is copied, never modified.
        batch_size: Number of paths per batch. Must be >= 1.
        shuffle: If True, reshuffle the paths (with ``np.random``) at the
            start of every pass.

    Yields:
        Tuple ``(batch_l, batch_ab)`` of arrays sliced ``[..., 0:1]`` and
        ``[..., 1:]`` from the Lab batch.

    Raises:
        ValueError: on the first ``next()`` if *data_list* is empty or
            *batch_size* < 1. Either case would otherwise loop forever
            without yielding.
    """
    # Private copy: shuffling must not reorder the caller's list (e.g. the
    # global train_lists) as a side effect.
    items = list(data_list)
    if not items:
        raise ValueError("data_list is empty; the generator would never yield")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    while True:
        if shuffle:
            np.random.shuffle(items)
        for start in range(0, len(items), batch_size):
            batch_lab = get_lab_from_data_list(items[start:start + batch_size])
            yield batch_lab[:, :, :, 0:1], batch_lab[:, :, :, 1:]
# Steps per pass over each split: ceil, so the generator's final partial
# batch is also consumed.
batch_size = 128
train_steps, val_steps, test_steps = (
    math.ceil(len(paths) / batch_size)
    for paths in (train_lists, val_lists, test_lists)
)

# One endless batch stream per split; only the training stream reshuffles.
train_gen = generator_with_preprocessing(train_lists, batch_size, shuffle=True)
val_gen = generator_with_preprocessing(val_lists, batch_size)
test_gen = generator_with_preprocessing(test_lists, batch_size)

epochs = 10
autoencoder.fit_generator(
    generator=train_gen,
    steps_per_epoch=train_steps,
    epochs=epochs,
    validation_data=val_gen,
    validation_steps=val_steps,
)
preds = autoencoder.predict_generator(test_gen, steps=test_steps, verbose=0)
# Rebuild the test inputs/targets in the order predict_generator consumed them.
# test_gen was created without shuffle, so a fresh generator over test_lists
# replays the same batches. Stop after test_steps batches because the
# generator never ends on its own.
x_test = []
y_test = []
for step, (l_batch, ab_batch) in enumerate(generator_with_preprocessing(test_lists, batch_size)):
    x_test.append(l_batch)
    y_test.append(ab_batch)
    if step == test_steps - 1:
        break
x_test = np.vstack(x_test)
y_test = np.vstack(y_test)

# The model emits unbounded float a/b values. A bare astype(np.uint8)
# truncates instead of rounding and does not saturate out-of-range values
# (the result is platform-dependent), which shows up as false colours.
# Round and clip into the valid 8-bit range before casting.
preds_ab = np.clip(np.rint(preds), 0, 255)
test_preds_lab = np.concatenate((x_test, preds_ab), axis=3).astype(np.uint8)
test_preds_rgb = np.stack([lab2rgb(lab_image) for lab_image in test_preds_lab])
from IPython.display import display_png
from PIL import Image, ImageOps
# Preview at most the first 21 test predictions (indices 0..20). Each one is
# shown first in grayscale, derived from the colourised image, and then in colour.
for rgb_image in test_preds_rgb[:21]:
    colour_pil = array_to_img(rgb_image)
    display_png(ImageOps.grayscale(colour_pil))
    display_png(colour_pil)
    print('-' * 25)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment