Created
December 14, 2021 19:08
-
-
Save ker2x/78d6aef870173c62f207f40d697bb6d5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#%% | |
import tensorflow as tf | |
import pathlib | |
import os | |
import time | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import pandas as pd | |
# Optionally set memory growth to True so TensorFlow allocates GPU memory on
# demand instead of reserving it all at start-up.
# -----------------------------------------------------------------------
physical_devices = tf.config.list_physical_devices('GPU')
# Iterate instead of indexing [0]: on a machine with no visible GPU the list is
# empty and [0] would raise IndexError. TensorFlow also expects the
# memory-growth setting to be identical across all GPUs, so apply it to each.
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)
#%%
# Dataset root holding the paired "color" and "bw" JPEG directories.
# Set the RECOLORIZE_DATA_DIR environment variable to point at another
# location; otherwise the original external-drive path is used.
_data_root = pathlib.Path(os.environ.get("RECOLORIZE_DATA_DIR", "/Volumes/T7/coco/recolorize"))
datadir_color = _data_root / "color"
datadir_bw = _data_root / "bw"
#%%
def list_image(path: pathlib.Path, pattern: str = "*.jpg") -> "list[str]":
    """Return the paths (as strings) of files in *path* matching *pattern*.

    The result is sorted so the order is deterministic: ``Path.glob`` yields
    entries in filesystem order, which is arbitrary.

    Args:
        path: directory to search (non-recursive unless *pattern* says so).
        pattern: glob pattern; defaults to ``"*.jpg"``.

    Returns:
        Sorted list of matching file paths as strings.
    """
    return sorted(str(img_path) for img_path in path.glob(pattern))
#%%
class ImageSequence(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding ``(input, target)`` batches of greyscale images.

    Each item is a batch of ONE greyscale image, cropped/padded to
    ``IMAGE_SIZE`` x ``IMAGE_SIZE`` and cast to float32 (pixel values are left
    in the 0-255 range). The same tensor is returned as input and target: the
    model built further down has a single output channel, so the colour images
    are listed but not yet used as targets.

    Args:
        color_path: directory of colour JPEGs whose names start with ``0``.
        bw_path: directory of greyscale JPEGs whose names start with ``0``.
        limit: maximum number of items per epoch; ``None`` means use every
            greyscale image. Defaults to 50, the cap the script used before.
    """

    IMAGE_SIZE = 128

    def __init__(self, color_path, bw_path, limit=50):
        super().__init__()
        # Sort both listings so index i refers to the same file name in each;
        # Path.glob returns entries in arbitrary filesystem order.
        self.color_list = sorted(str(img_path) for img_path in color_path.glob("0*.jpg"))
        self.bw_list = sorted(str(img_path) for img_path in bw_path.glob("0*.jpg"))
        self.limit = limit

    def __len__(self):
        """Number of batches per epoch, never more than the files available."""
        available = len(self.bw_list)
        return available if self.limit is None else min(available, self.limit)

    def __getitem__(self, idx):
        """Return ``(batch, batch)`` where ``batch`` has shape (1, 128, 128, 1)."""
        image_bw = self._load(self.bw_list[idx], channels=1)
        # Keras treats the leading axis as the batch axis; without this the
        # image's 128 rows would be fed to the model as separate samples.
        batch = tf.expand_dims(image_bw, axis=0)
        return batch, batch

    @classmethod
    def _load(cls, path, channels):
        """Read a JPEG, crop/pad it to IMAGE_SIZE square and cast to float32."""
        image = tf.image.decode_jpeg(tf.io.read_file(path), channels=channels)
        image = tf.image.resize_with_crop_or_pad(
            image=image, target_height=cls.IMAGE_SIZE, target_width=cls.IMAGE_SIZE)
        return tf.cast(image, tf.float32)
class genImage(tf.keras.callbacks.Callback):
    """Keras callback that plots the model's prediction for one sample image after each epoch.

    The sample is preprocessed the same way as ``ImageSequence`` items
    (128x128 crop/pad, float32, leading batch axis) and loaded only once.

    Args:
        image_dir: directory of greyscale ``0*.jpg`` files; when ``None`` the
            module-level ``datadir_bw`` is used.
        sample_index: index into the sorted file listing of the image to render.
        figure: matplotlib figure number to draw into.
    """

    def __init__(self, image_dir=None, sample_index=3000, figure=4):
        super().__init__()
        self.image_dir = image_dir
        self.sample_index = sample_index
        self.figure = figure
        self._sample = None  # cached (1, 128, 128, 1) input tensor

    def _load_sample(self):
        """Load and preprocess the sample image on first use, then reuse it."""
        if self._sample is None:
            image_dir = datadir_bw if self.image_dir is None else self.image_dir
            paths = sorted(str(img_path) for img_path in image_dir.glob("0*.jpg"))
            image = tf.image.decode_jpeg(tf.io.read_file(paths[self.sample_index]), channels=1)
            image = tf.image.resize_with_crop_or_pad(image=image, target_height=128, target_width=128)
            # Leading batch axis so predict() sees one 128x128x1 sample.
            self._sample = tf.expand_dims(tf.cast(image, tf.float32), axis=0)
        return self._sample

    def on_epoch_end(self, epoch, logs=None):
        # self.model is the model being trained (set by Keras), not the global.
        predictions = self.model.predict(self._load_sample())
        predictions /= 255.0
        plt.figure(self.figure)
        plt.clf()
        # Drop the batch/channel axes: imshow needs a 2-D array for greyscale.
        plt.imshow(tf.squeeze(predictions), cmap='gray')
        plt.title(f"epoch {epoch + 1}")
        plt.show()
# Training hyper-parameters.
EPOCHS: int = 32  # passed as `epochs` to model.fit
LR: float = 5e-4  # NOTE(review): the optimizer below is built without it (Adam default LR)

# Dense-network experiment settings; only referenced by the disabled sketch below.
HIDDENLAYERS: int = 4
LAYERWIDTH: int = 9

# Disabled dense-model sketch kept for reference:
#   tf.keras.Input(shape=(1,))
#   model.add(tf.keras.layers.Flatten((255, 255, 3)))
#   model.add(tf.keras.layers.Dense(512, activation="swish"))
#   for _ in range(HIDDENLAYERS):
#       model.add(tf.keras.layers.Dense(LAYERWIDTH, activation="swish"))
class Recolor(tf.keras.Model):
    """Small fully-convolutional encoder/decoder (not currently instantiated by the script).

    Every layer is a stride-1, ``'same'``-padded 3x3 (transpose) convolution,
    so the output has the input's height and width and one sigmoid channel.
    The input spatial size is left as ``None`` so any image size is accepted,
    including the 128x128 crops produced by ``ImageSequence`` (the encoder was
    previously pinned to 64x64, which did not match that data).

    NOTE(review): the sigmoid output lies in [0, 1] while ``ImageSequence``
    yields unnormalised 0-255 float targets; normalise before training this.
    """

    def __init__(self):
        super().__init__()
        self.encoder = tf.keras.Sequential([
            tf.keras.layers.Input(shape=(None, None, 1)),
            tf.keras.layers.Conv2D(16, (3, 3), activation='relu', padding='same', strides=1),
            tf.keras.layers.Conv2D(8, (3, 3), activation='relu', padding='same', strides=1),
        ])
        self.decoder = tf.keras.Sequential([
            tf.keras.layers.Conv2DTranspose(8, kernel_size=3, strides=1, activation='relu', padding='same'),
            tf.keras.layers.Conv2DTranspose(16, kernel_size=3, strides=1, activation='relu', padding='same'),
            tf.keras.layers.Conv2D(1, kernel_size=(3, 3), activation='sigmoid', padding='same'),
        ])

    def call(self, x):
        """Encode then decode a batch ``x`` of shape (batch, H, W, 1)."""
        return self.decoder(self.encoder(x))
# Active model: three 3x3 'same'-padded convolutions (16 -> 8 -> 1 filters,
# GELU activations) mapping a (batch, 128, 128, 1) input to a single-channel
# output of the same spatial size.
# Alternative: model = Recolor()
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(16, (3, 3), activation='gelu', padding='same', strides=1,
                           data_format="channels_last", input_shape=(128, 128, 1)),
    tf.keras.layers.Conv2D(8, (3, 3), activation='gelu', padding='same', strides=1,
                           data_format="channels_last"),
    tf.keras.layers.Conv2D(1, kernel_size=(3, 3), activation='gelu', padding='same',
                           data_format="channels_last"),
])

# Mean-squared error with Adam at its default learning rate (LR is not passed).
# Other options tried: Adadelta(learning_rate=1.0); metrics=["accuracy", "mae", "mse"].
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.MeanSquaredError(),
)
# Train on the greyscale image sequence and tabulate per-epoch metrics.
sequence = ImageSequence(datadir_color, datadir_bw)
history = model.fit(sequence, epochs=EPOCHS, callbacks=[])
model.summary()

np.set_printoptions(precision=3, suppress=True)
hist = pd.DataFrame(history.history)
hist['epoch'] = history.epoch
# A bare `hist.tail()` only displays when it is the last expression of a
# notebook cell; print it so the table is visible when run as a script.
print(hist.tail())
plt.figure(1)


def plot_loss(history, ylim=(0, 2)):
    """Plot the per-epoch training loss, plus validation loss when it was recorded.

    Args:
        history: the ``History`` object returned by ``model.fit``.
        ylim: y-axis limits passed to ``plt.ylim``. ``None`` lets matplotlib
            autoscale. The targets are unnormalised 0-255 pixel values, so the
            loss may well sit above the default upper limit of 2.
    """
    plt.plot(history.history['loss'], label='training loss')
    # Only present when fit() was given validation data.
    if 'val_loss' in history.history:
        plt.plot(history.history['val_loss'], label='validation loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    if ylim is not None:
        plt.ylim(ylim)
    plt.legend()
    plt.grid(True)
    plt.show()


plot_loss(history)
# Run the trained model on one sample image and show the input next to the prediction.
# Sorted so that index 301 names the same file on every run.
bw_list = sorted(str(img_path) for img_path in datadir_bw.glob("0*.jpg"))
file_bw = tf.io.read_file(bw_list[301])
image_bw = tf.image.decode_jpeg(file_bw, channels=1)
image_bw = tf.image.resize_with_crop_or_pad(image=image_bw, target_height=128, target_width=128)
# Add a leading batch axis: the model expects (batch, 128, 128, 1). Passing the
# bare (128, 128, 1) image makes predict() treat its 128 rows as samples.
image_bw2 = tf.expand_dims(tf.cast(image_bw, tf.float32), axis=0)
predictions = model.predict(image_bw2)
predictions /= 255.0
plt.figure(2)
plt.imshow(tf.squeeze(image_bw), cmap='gray')
plt.figure(3)
plt.imshow(tf.squeeze(predictions), cmap='gray')
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment