# Created December 14, 2021 19:48
# Save ker2x/95f67a30c1a3f4ef0fa347aa14311c03 to your computer and use it in GitHub Desktop.
# This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
#%%
import tensorflow as tf
import pathlib
import os
import time
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
# The setting is applied to every visible GPU: TensorFlow refuses configurations
# where memory growth differs between GPUs. The loop is simply a no-op on a
# CPU-only machine, where the original `physical_devices[0]` raised IndexError.
physical_devices = tf.config.list_physical_devices('GPU')
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

# Directory containing the colour JPEG images loaded further down.
datadir_color = pathlib.Path("/Volumes/T7/coco/recolorize/color")
def list_image(path: pathlib.Path, pattern: str = "*.jpg") -> "list[str]":
    """Return the images in directory *path* whose names match *pattern*.

    Args:
        path: Directory to search. Matching follows ``Path.glob`` semantics,
            so a plain pattern does not descend into subdirectories.
        pattern: Glob pattern for the file names. The default ``"*.jpg"``
            reproduces the original behaviour.

    Returns:
        The matching paths converted with ``str``, in sorted order.
        ``Path.glob`` yields entries in filesystem order, so sorting gives a
        stable order across machines and runs.
    """
    return sorted(str(img_path) for img_path in path.glob(pattern))
# Training hyper-parameters.
EPOCHS = 8
LR = 0.0005  # learning rate for the Adam optimizer below

# Small fully-convolutional network on 128x128 RGB inputs: two 3x3 GELU
# convolutions (16 and 8 filters), then a single-filter 3x3 convolution.
# 'same' padding with stride 1 keeps the 128x128 spatial size throughout.
# NOTE(review): the last layer emits ONE channel while the training target
# (train_y, decoded with channels=3) is RGB; the MSE loss presumably broadcasts
# the single channel against all three — confirm this is intended.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Conv2D(16, (3, 3), activation='gelu', padding='same', strides=1,
                                 data_format="channels_last", input_shape=(128, 128, 3)))
model.add(tf.keras.layers.Conv2D(8, (3, 3), activation='gelu', padding='same', strides=1,
                                 data_format="channels_last"))
model.add(tf.keras.layers.Conv2D(1, kernel_size=(3, 3), activation='gelu', padding='same',
                                 data_format="channels_last"))

model.compile(
    loss=tf.keras.losses.MeanSquaredError(),
    # LR was previously defined but never passed here, so Adam silently
    # trained with its default learning rate instead of the configured one.
    optimizer=tf.keras.optimizers.Adam(learning_rate=LR),
)
# Build the training set: up to NUM_SAMPLES colour images whose names start
# with "0". Each image is cropped or padded to 128x128 and used as both the
# input (train_x) and the target (train_y).
NUM_SAMPLES = 50

# Sort so the same files are selected on every run; glob order is unspecified.
color_list = sorted(str(img_path) for img_path in datadir_color.glob("0*.jpg"))
if not color_list:
    # Fail with a clear message instead of an opaque shape error inside fit().
    raise FileNotFoundError(f"no '0*.jpg' images found in {datadir_color}")

images = []
# Slicing (instead of range(50)) tolerates directories with fewer images.
for file_path in color_list[:NUM_SAMPLES]:
    file_color = tf.io.read_file(file_path)
    image_color = tf.image.decode_jpeg(file_color, channels=3)  # force 3 channels
    image_color = tf.image.resize_with_crop_or_pad(image=image_color, target_height=128, target_width=128)
    images.append(image_color)

train_x = np.array(images)
train_y = train_x.copy()  # target is the input image itself
# Train on train_x -> train_y (the same images) with mini-batches of two,
# then print the per-layer summary of the network.
model.fit(
    x=train_x,
    y=train_y,
    epochs=EPOCHS,
    batch_size=2,
    callbacks=[],
)
model.summary()
# Run the trained model on the first training image and display that same
# image next to the model's output.
img = train_x[0]
# model.predict treats the first axis as the batch axis. A bare (128, 128, 3)
# image would therefore be read as 128 samples of shape (128, 3), which do not
# match the model's input shape. Adding a leading axis gives a batch of one.
predictions = model.predict(img[np.newaxis, ...])
# The targets are raw 0-255 pixel values, so rescale the output to [0, 1].
predictions /= 255.0

plt.figure(2)
# Show the image that was actually fed to the model; the original plotted
# train_x[1], a different sample from the one predicted.
plt.imshow(tf.squeeze(img), cmap='gray')

plt.figure(3)
# Fix the colour limits. Otherwise imshow rescales the single-channel output
# to its own min/max, which makes the /255 above invisible and hides the
# prediction's true brightness.
plt.imshow(tf.squeeze(predictions), cmap='gray', vmin=0.0, vmax=1.0)
plt.show()
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.