Skip to content

Instantly share code, notes, and snippets.

@ker2x
Created December 14, 2021 19:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ker2x/78d6aef870173c62f207f40d697bb6d5 to your computer and use it in GitHub Desktop.
Save ker2x/78d6aef870173c62f207f40d697bb6d5 to your computer and use it in GitHub Desktop.
#%%
import tensorflow as tf
import pathlib
import os
import time
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# Optionally set memory growth to True
# ------------------------------------
# With memory growth enabled TensorFlow allocates GPU memory on demand
# instead of reserving the whole device up front.
physical_devices = tf.config.list_physical_devices('GPU')
# The list is empty on CPU-only machines, so indexing it unconditionally
# would raise IndexError.  TensorFlow also requires the memory-growth
# setting to be identical on every visible GPU, hence the loop.
for _gpu in physical_devices:
    tf.config.experimental.set_memory_growth(_gpu, True)
#%%
# Dataset directories on the external volume: one holds the colour JPEGs,
# the other the grayscale ("bw") JPEGs.
datadir_color: pathlib.Path = pathlib.Path(
    "/Volumes/T7/coco/recolorize/color"
)
datadir_bw: pathlib.Path = pathlib.Path(
    "/Volumes/T7/coco/recolorize/bw"
)
#%%
# List image in directory
def list_image(path: pathlib.Path):
    """Return the ``*.jpg`` files directly inside *path* as sorted strings.

    ``Path.glob`` yields entries in an arbitrary, filesystem-dependent
    order, so the result is sorted to make it reproducible between runs
    and machines.

    Args:
        path: Directory to scan (not recursive).

    Returns:
        A list of path strings, sorted lexicographically; empty when the
        directory contains no ``.jpg`` file.
    """
    return sorted(str(img_path) for img_path in path.glob("*.jpg"))
#%%
class ImageSequence(tf.keras.utils.Sequence):
    """Keras ``Sequence`` serving one grayscale image per batch.

    Every item is an ``(inputs, targets)`` pair in which both elements are
    the *same* grayscale tensor, i.e. the data is set up for an
    identity/auto-encoder style training run.  Pixel values are cast to
    ``float32`` but not rescaled.

    Attributes:
        color_list: Sorted paths of the colour JPEGs matching ``0*.jpg``.
            Collected but not read by ``__getitem__``.
            NOTE(review): presumably meant to become the training target
            later on - confirm with the author.
        bw_list: Sorted paths of the grayscale JPEGs matching ``0*.jpg``.
    """

    # Side length, in pixels, of the square every image is cropped/padded to.
    IMAGE_SIZE = 128
    # Upper bound on the number of batches served per epoch.
    MAX_BATCHES = 50

    def __init__(self, color_path, bw_path):
        """Collect the image paths found in the two directories.

        Args:
            color_path: ``pathlib.Path`` of the colour image directory.
            bw_path: ``pathlib.Path`` of the grayscale image directory.
        """
        super().__init__()
        # glob() order is filesystem-dependent; sorting keeps the mapping
        # from index to file stable and the two lists aligned by name.
        self.color_list = sorted(str(p) for p in color_path.glob("0*.jpg"))
        self.bw_list = sorted(str(p) for p in bw_path.glob("0*.jpg"))

    def __len__(self):
        """Return the number of batches per epoch.

        Capped at ``MAX_BATCHES`` and never larger than the number of
        files available, so Keras cannot request a missing index.
        """
        return min(self.MAX_BATCHES, len(self.bw_list))

    def __getitem__(self, idx):
        """Load grayscale image *idx* and return it as a batch of one.

        Args:
            idx: Batch index, ``0 <= idx < len(self)``.

        Returns:
            ``(x, y)`` where both are the same ``float32`` tensor of shape
            ``(1, IMAGE_SIZE, IMAGE_SIZE, 1)``.
        """
        raw = tf.io.read_file(self.bw_list[idx])
        image_bw = tf.image.decode_jpeg(raw, channels=1)
        # Centre-crop or zero-pad to a fixed square; no rescaling is done.
        image_bw = tf.image.resize_with_crop_or_pad(
            image=image_bw,
            target_height=self.IMAGE_SIZE,
            target_width=self.IMAGE_SIZE,
        )
        # A Sequence must return whole batches: add the leading batch axis
        # so Keras does not interpret image rows as individual samples.
        batch = tf.cast(image_bw, tf.float32)[tf.newaxis, ...]
        return batch, batch
class genImage(tf.keras.callbacks.Callback):
    """Callback that displays the model's output for one sample per epoch."""

    # Index of the preview image in the sorted file list.
    SAMPLE_INDEX = 3000
    # Side length the preview image is cropped/padded to before prediction.
    IMAGE_SIZE = 128

    def on_epoch_end(self, epoch, logs=None):
        """Predict on a fixed grayscale image and show the result.

        Args:
            epoch: Index of the epoch that just finished (unused).
            logs: Metric results for the epoch (unused).
        """
        bw_list = sorted(str(p) for p in datadir_bw.glob("0*.jpg"))
        if not bw_list:
            # Nothing to preview; do not abort training over it.
            return
        # Clamp so that a smaller dataset does not raise IndexError.
        sample_path = bw_list[min(self.SAMPLE_INDEX, len(bw_list) - 1)]

        file_bw = tf.io.read_file(sample_path)
        image_bw = tf.image.decode_jpeg(file_bw, channels=1)
        # Crop/pad, cast and add the batch axis before calling predict().
        image_bw = tf.image.resize_with_crop_or_pad(
            image=image_bw,
            target_height=self.IMAGE_SIZE,
            target_width=self.IMAGE_SIZE,
        )
        batch = tf.cast(image_bw, tf.float32)[tf.newaxis, ...]

        # self.model is the model being trained (set by Keras), so the
        # callback does not depend on a module-level ``model`` variable.
        predictions = self.model.predict(batch)
        predictions /= 255.0

        # Draw on the figure that was just cleared, and drop the batch and
        # channel axes because imshow expects a 2-D array for gray images.
        plt.figure(4)
        plt.clf()
        plt.imshow(tf.squeeze(predictions), cmap='gray')
        plt.show()
# Training hyper-parameters.
# EPOCHS is passed to model.fit; the other three are only referenced by
# commented-out experiments in this file.
EPOCHS: int = 32
LR: float = 5e-4
HIDDENLAYERS: int = 4
LAYERWIDTH: int = 9
#tf.keras.Input(shape=(1,))
#model.add(tf.keras.layers.Flatten((255,255,3)))
#model.add(tf.keras.layers.Dense(512, activation="swish"))
#for _ in range(HIDDENLAYERS):
# model.add(tf.keras.layers.Dense(LAYERWIDTH, activation="swish"))
class Recolor(tf.keras.Model):
    """Small convolutional encoder/decoder for single-channel images.

    All layers use stride 1 and ``same`` padding, so the spatial size of
    the output equals that of the input.  The final layer has one filter
    with a sigmoid activation, so outputs lie in ``[0, 1]``.

    NOTE(review): ``ImageSequence`` in this file yields un-normalised
    pixel values; confirm the target scaling before training this model.
    """

    def __init__(self):
        super().__init__()
        self.encoder = tf.keras.Sequential([
            # Height and width are left unspecified: the network is fully
            # convolutional, so it accepts the 128x128 images used in the
            # rest of this file as well as the previous fixed 64x64 size.
            tf.keras.layers.Input(shape=(None, None, 1)),
            tf.keras.layers.Conv2D(16, (3, 3), activation='relu', padding='same', strides=1),
            tf.keras.layers.Conv2D(8, (3, 3), activation='relu', padding='same', strides=1),
        ])
        self.decoder = tf.keras.Sequential([
            tf.keras.layers.Conv2DTranspose(8, kernel_size=3, strides=1, activation='relu', padding='same'),
            tf.keras.layers.Conv2DTranspose(16, kernel_size=3, strides=1, activation='relu', padding='same'),
            tf.keras.layers.Conv2D(1, kernel_size=(3, 3), activation='sigmoid', padding='same'),
        ])

    def call(self, x):
        """Run *x* through the encoder and then the decoder.

        Args:
            x: Batch of single-channel images, channels last.

        Returns:
            The decoder output for the encoded batch.
        """
        return self.decoder(self.encoder(x))
#model = Recolor()
# Three stacked 3x3 convolutions (16 -> 8 -> 1 filters), all with GELU
# activation, stride 1 and "same" padding, on 128x128 single-channel input.
# Earlier experiments (transposed convolutions, a Dense head, an explicit
# InputLayer) were left disabled by the author and are not reproduced here.
_conv_common = dict(activation='gelu', padding='same', data_format="channels_last")
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(16, (3, 3), strides=1, input_shape=(128, 128, 1), **_conv_common),
    tf.keras.layers.Conv2D(8, (3, 3), strides=1, **_conv_common),
    tf.keras.layers.Conv2D(1, kernel_size=(3, 3), **_conv_common),
])

# Mean-squared-error loss with Adam at its default settings; no extra
# metrics are tracked.
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.MeanSquaredError(),
)

# Train from the image sequence, without callbacks.
sequence = ImageSequence(datadir_color, datadir_bw)
history = model.fit(sequence, epochs=EPOCHS, callbacks=[])
model.summary()

# Collect the per-epoch history in a DataFrame for inspection.
np.set_printoptions(precision=3, suppress=True)
hist = pd.DataFrame(history.history)
hist['epoch'] = history.epoch
hist.tail()

# Select figure 1 as the current figure for the loss plot that follows.
plt.figure(1)
def plot_loss(history):
    """Plot the training-loss curve on the current matplotlib figure.

    Args:
        history: Object whose ``history`` mapping holds a ``'loss'``
            series, such as the value returned by ``model.fit``.

    The y-axis is fixed to the range 0..2 and the figure is shown
    immediately.  Validation loss is not plotted.
    """
    training_loss = history.history['loss']
    plt.plot(training_loss, label='training loss')
    plt.ylim((0, 2))
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.grid(True)
    plt.legend()
    plt.show()
plot_loss(history)

# Run the trained model on one grayscale sample and show input and output.
# Sorting makes index 301 refer to the same file on every run, because
# glob() order is filesystem-dependent.
bw_list = sorted(str(img_path) for img_path in datadir_bw.glob("0*.jpg"))
file_bw = tf.io.read_file(bw_list[301])
image_bw = tf.image.decode_jpeg(file_bw, channels=1)
image_bw = tf.image.resize_with_crop_or_pad(image=image_bw, target_height=128, target_width=128)
image_bw2 = tf.cast(image_bw, tf.float32)
#print(image_bw2)

# predict() expects a batch: add the leading axis so the (128, 128, 1)
# image is one sample instead of its rows being treated as 128 samples.
predictions = model.predict(image_bw2[tf.newaxis, ...])
predictions /= 255.0 #tf.cast(predictions_f, tf.int8)

# Figure 2: the network input; figure 3: the network output.
plt.figure(2)
plt.imshow(tf.squeeze(image_bw), cmap='gray')
plt.figure(3)
plt.imshow(tf.squeeze(predictions), cmap='gray')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment