Skip to content

Instantly share code, notes, and snippets.

@ker2x
Created December 14, 2021 19:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ker2x/95f67a30c1a3f4ef0fa347aa14311c03 to your computer and use it in GitHub Desktop.
Save ker2x/95f67a30c1a3f4ef0fa347aa14311c03 to your computer and use it in GitHub Desktop.
#%%
import tensorflow as tf
import pathlib
import os
import time
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# Let TensorFlow grow its GPU memory allocation on demand. Looping over the
# detected devices (instead of indexing [0]) keeps the script runnable on
# CPU-only machines, where the list is empty, and applies the same setting
# to every GPU, which TensorFlow requires when more than one is present.
physical_devices = tf.config.list_physical_devices('GPU')
for gpu_device in physical_devices:
    tf.config.experimental.set_memory_growth(gpu_device, True)

# Directory that the training images are read from.
datadir_color = pathlib.Path("/Volumes/T7/coco/recolorize/color")
# List image in directory
def list_image(path: pathlib.Path) -> list:
    """Return the paths of the ``*.jpg`` files directly inside *path*.

    Args:
        path: Directory to scan. Sub-directories are not searched.

    Returns:
        The matching file paths as strings, sorted so that the result is
        reproducible (``Path.glob`` yields entries in arbitrary,
        filesystem-dependent order). Empty if nothing matches.
    """
    return sorted(str(img_path) for img_path in path.glob("*.jpg"))
# Training hyper-parameters.
EPOCHS = 8   # number of passes over the training set
LR = 0.0005  # Adam learning rate

# Three stacked 3x3 convolutions. 'same' padding with stride 1 preserves the
# 128x128 spatial size, so the network maps a 128x128x3 image to a
# 128x128x1 output.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Conv2D(16, (3, 3), activation='gelu', padding='same', strides=1, data_format="channels_last", input_shape=(128, 128, 3)))
model.add(tf.keras.layers.Conv2D(8, (3, 3), activation='gelu', padding='same', strides=1, data_format="channels_last"))
model.add(tf.keras.layers.Conv2D(1, kernel_size=(3, 3), activation='gelu', padding='same', data_format="channels_last"))

# Pass LR explicitly: without it Adam falls back to its built-in default
# learning rate and the constant above has no effect.
model.compile(
    loss=tf.keras.losses.MeanSquaredError(),
    optimizer=tf.keras.optimizers.Adam(learning_rate=LR),
)
# Collect the candidate training files. Sorting makes the selection
# reproducible, because glob() yields entries in arbitrary order.
color_list = sorted(str(img_path) for img_path in datadir_color.glob("0*.jpg"))
if not color_list:
    raise FileNotFoundError(f"no images matching '0*.jpg' found in {datadir_color}")

train_x = []
train_y = []
# Use at most the first 50 images. Slicing (rather than indexing 0..49)
# avoids an IndexError when the directory holds fewer than 50 matches.
for color_path in color_list[:50]:
    file_color = tf.io.read_file(color_path)
    image_color = tf.image.decode_jpeg(file_color, channels=3)
    # Centre-crop or zero-pad to 128x128 so every sample matches the
    # model's input_shape.
    image_color = tf.image.resize_with_crop_or_pad(image=image_color, target_height=128, target_width=128)
    # The same image is used as both the input and the training target.
    train_x.append(image_color)
    train_y.append(image_color)

# Stack the per-image tensors into single arrays with a leading sample axis.
train_x = np.array(train_x)
train_y = np.array(train_y)
# Fit the network on the in-memory arrays (two samples per batch, no
# callbacks), then print the layer-by-layer summary.
model.fit(
    x=train_x,
    y=train_y,
    batch_size=2,
    epochs=EPOCHS,
    callbacks=[],
)
model.summary()
# Run the trained model on the first training image. predict() expects a
# batch of samples, so add a leading axis: (128, 128, 3) -> (1, 128, 128, 3).
# Passing the bare image would make Keras treat its rows as the batch axis.
img = train_x[0]
predictions = model.predict(img[np.newaxis, ...])
# The training data was fed as raw 0-255 pixel values; rescale for display.
predictions /= 255.0

# Figure 2: the input image that was actually fed to the model, so the two
# figures can be compared directly. No colormap is passed because matplotlib
# ignores cmap for RGB data.
plt.figure(2)
plt.imshow(img)

# Figure 3: the single-channel prediction; squeeze drops the batch and
# channel axes, leaving a 2-D array that is rendered with a grey colormap.
plt.figure(3)
plt.imshow(np.squeeze(predictions), cmap='gray')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment