Created
December 14, 2021 19:47
-
-
Save ker2x/29f080c848e52c9b0a7f3767baed0479 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#%%
import tensorflow as tf
import pathlib
import os
import time
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
# Iterating (instead of indexing [0]) keeps the script working on CPU-only
# machines, where list_physical_devices('GPU') returns an empty list.
physical_devices = tf.config.list_physical_devices('GPU')
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

# Directory containing the color .jpg images used to build the training set.
datadir_color = pathlib.Path("/Volumes/T7/coco/recolorize/color")
def list_image(path: pathlib.Path, pattern: str = "*.jpg"):
    """Return the image files directly inside *path* as sorted string paths.

    Args:
        path: Directory to search. Only its direct children are matched;
            the search is not recursive.
        pattern: Glob pattern matched against file names. Defaults to ``"*.jpg"``.

    Returns:
        A list of ``str`` paths, sorted so the order is reproducible
        (``Path.glob`` yields entries in filesystem-dependent order).
    """
    return sorted(str(img_path) for img_path in path.glob(pattern))
EPOCHS = 8
LR = 0.0005

# Small fully-convolutional network: 128x128 RGB input. Every layer uses
# 'same' padding with stride 1, so the spatial size stays 128x128 throughout.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Conv2D(16, (3, 3), activation='gelu', padding='same', strides=1,
                                 data_format="channels_last", input_shape=(128, 128, 3)))
model.add(tf.keras.layers.Conv2D(8, (3, 3), activation='gelu', padding='same', strides=1,
                                 data_format="channels_last"))
# NOTE(review): the output layer has a single channel while the training targets
# are 3-channel images; Keras' MSE broadcasts this one channel against all three.
# Confirm that a single (grayscale-like) output is intended.
model.add(tf.keras.layers.Conv2D(1, kernel_size=(3, 3), activation='gelu', padding='same',
                                 data_format="channels_last"))
model.compile(loss=tf.keras.losses.MeanSquaredError(),
              # Pass LR explicitly; previously it was defined but unused, so Adam
              # silently trained with its own default learning rate.
              optimizer=tf.keras.optimizers.Adam(learning_rate=LR),
              )
# Training set: up to 50 color images whose names start with "0". Each image is
# used as both input and target. Sort first: glob order is filesystem-dependent,
# and sorting makes the chosen subset reproducible.
color_list = sorted(str(img_path) for img_path in datadir_color.glob("0*.jpg"))
if not color_list:
    raise FileNotFoundError(f"no images matching '0*.jpg' found in {datadir_color}")

train_x = []
# Slicing (instead of range(50)) tolerates directories with fewer than 50 images.
for image_path in color_list[:50]:
    file_color = tf.io.read_file(image_path)
    image_color = tf.image.decode_jpeg(file_color, channels=3)
    # Center-crop or zero-pad to the model's 128x128 input size.
    image_color = tf.image.resize_with_crop_or_pad(image=image_color, target_height=128, target_width=128)
    train_x.append(image_color.numpy())

train_x = np.stack(train_x)
# Targets are the inputs themselves; copy so the two arrays never alias.
train_y = train_x.copy()
# Train, then report the architecture and the last few epochs of the loss history.
history = model.fit(train_x, train_y, batch_size=2, epochs=EPOCHS)
model.summary()

np.set_printoptions(precision=3, suppress=True)
hist = pd.DataFrame(history.history)
hist['epoch'] = history.epoch
# A bare `hist.tail()` is discarded when run as a script; print it explicitly.
print(hist.tail())
plt.figure(1)


def plot_loss(history):
    """Plot the per-epoch training loss from a Keras ``History`` and show the figure.

    Args:
        history: The object returned by ``model.fit``. Its
            ``history.history['loss']`` list is plotted on the current figure.
    """
    plt.plot(history.history['loss'], label='training loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    # Only pin the lower bound: MSE is non-negative. A fixed upper limit of 2
    # hid the curve entirely, because the loss on unnormalized 0-255 pixels is
    # far larger.
    plt.ylim(bottom=0)
    plt.legend()
    plt.grid(True)
    plt.show()


plot_loss(history)
# Run the trained model on one training image and display input and prediction.
img = train_x[0]
# model.predict expects a batch: add a leading batch axis of size 1. Passing the
# bare (128, 128, 3) array would be read as 128 samples of shape (128, 3).
predictions = model.predict(img[np.newaxis, ...])
# The model was fit to raw 0-255 pixel targets; rescale its output to [0, 1].
predictions /= 255.0
plt.figure(2)
# Show the same image that was fed to the model (previously train_x[1] was shown).
plt.imshow(tf.squeeze(img), cmap='gray')
plt.figure(3)
plt.imshow(tf.squeeze(predictions), cmap='gray')
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment