U-Net Keras
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, Input, Concatenate, MaxPool2D, Conv2DTranspose, Add
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import Callback, History
import tensorflow.keras.backend as K
from tensorflow.keras.losses import mean_squared_error
import os, tarfile, shutil, pickle
from PIL import Image
from tensorflow.contrib.tpu.python.tpu import keras_support
## UNet
def create_block(input, chs):
    x = input
    for i in range(2):
        # The original U-Net uses unpadded convolutions, but "same" padding
        # avoids having to crop feature maps before concatenation
        x = Conv2D(chs, 3, padding="same")(x)
        x = BatchNormalization()(x)
        x = Activation("relu")(x)
    return x
def create_unet(use_skip_connections, grayscale_inputs=False):
    if grayscale_inputs:
        input = Input((96, 96, 1))
    else:
        input = Input((96, 96, 3))
    # Encoder
    block1 = create_block(input, 64)
    x = MaxPool2D(2)(block1)
    block2 = create_block(x, 128)
    x = MaxPool2D(2)(block2)
    block3 = create_block(x, 256)
    x = MaxPool2D(2)(block3)
    block4 = create_block(x, 512)
    # Middle
    x = MaxPool2D(2)(block4)
    x = create_block(x, 1024)
    # Decoder
    # UpSampling2D / K.resize_images are not available on TPU, so use transposed convolutions
    x = Conv2DTranspose(512, kernel_size=2, strides=2)(x)
    if use_skip_connections: x = Concatenate()([block4, x])
    x = create_block(x, 512)
    x = Conv2DTranspose(256, kernel_size=2, strides=2)(x)
    if use_skip_connections: x = Concatenate()([block3, x])
    x = create_block(x, 256)
    x = Conv2DTranspose(128, kernel_size=2, strides=2)(x)
    if use_skip_connections: x = Concatenate()([block2, x])
    x = create_block(x, 128)
    x = Conv2DTranspose(64, kernel_size=2, strides=2)(x)
    if use_skip_connections: x = Concatenate()([block1, x])
    x = create_block(x, 64)
    # Output
    x = Conv2D(3, 1)(x)
    x = Activation("sigmoid")(x)
    return Model(input, x)
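# --- Sketch (not in the original gist): a quick shape check. Because every block
# uses "same" padding and each MaxPool2D(2) is mirrored by a stride-2
# Conv2DTranspose, the output spatial size matches the 96x96 input.
def _check_unet_shapes():
    m = create_unet(use_skip_connections=True)
    assert m.input_shape == (None, 96, 96, 3)
    assert m.output_shape == (None, 96, 96, 3)
    m_gray = create_unet(use_skip_connections=False, grayscale_inputs=True)
    assert m_gray.input_shape == (None, 96, 96, 1)   # grayscale in, RGB out
    assert m_gray.output_shape == (None, 96, 96, 3)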
def loss_function(y_true, y_pred):
    mses = mean_squared_error(y_true, y_pred)
    return K.sum(mses, axis=(1, 2))
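# --- Note (not in the original gist): tf.keras' mean_squared_error averages over
# the last (channel) axis, so loss_function returns one value per image: the sum
# over all 96x96 pixels of the channel-mean squared error. Equivalent backend form:
def _loss_function_reference(y_true, y_pred):
    return K.sum(K.mean(K.square(y_pred - y_true), axis=-1), axis=(1, 2))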
import sys, os, urllib.request, tarfile, glob
import numpy as np
class STL10:
    def __init__(self, download_dir):
        self.binary_dir = os.path.join(download_dir, "stl10_binary")
        if not os.path.exists(download_dir):
            os.mkdir(download_dir)
        if not os.path.exists(self.binary_dir):
            os.mkdir(self.binary_dir)
        # download file
        def _progress(count, block_size, total_size):
            sys.stdout.write('\rDownloading %s %.2f%%' % (source_path,
                float(count * block_size) / float(total_size) * 100.0))
            sys.stdout.flush()
        source_path = "http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz"
        dest_path = os.path.join(download_dir, "stl10_binary.tar.gz")
        if not os.path.exists(dest_path):
            urllib.request.urlretrieve(source_path, filename=dest_path, reporthook=_progress)
        # untar
        with tarfile.open(dest_path, "r:gz") as tar:
            tar.extractall(path=download_dir)

    def get_files(self, target):
        assert target in ["train", "test", "unlabeled"]
        if target in ["train", "test"]:
            images = self.load_images(os.path.join(self.binary_dir, target + "_X.bin"))
            labels = self.load_labels(os.path.join(self.binary_dir, target + "_y.bin"))
        else:
            images = self.load_images(os.path.join(self.binary_dir, target + "_X.bin"))
            labels = None
        return images, labels

    def load_images(self, image_binary):
        with open(image_binary, "rb") as fp:
            images = np.fromfile(fp, dtype=np.uint8)
            images = images.reshape(-1, 3, 96, 96)
            # stored column-major as (C, W, H); convert to (H, W, C)
            return np.transpose(images, (0, 3, 2, 1))

    def load_labels(self, label_binary):
        with open(label_binary, "rb") as fp:  # binary mode for np.fromfile
            labels = np.fromfile(fp, dtype=np.uint8)
            return labels.reshape(-1, 1) - 1  # 1-10 -> 0-9
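# --- Sketch (not in the original gist), assuming the standard STL-10 splits
# (5,000 labeled train images, 8,000 test, 100,000 unlabeled, all 96x96 RGB):
def _stl10_example():
    stl10 = STL10("./stl10")                 # downloads the ~2.5 GB archive on first use
    X_train, y_train = stl10.get_files("train")
    print(X_train.shape, X_train.dtype)      # (5000, 96, 96, 3) uint8
    print(y_train.shape, y_train.min(), y_train.max())   # (5000, 1) 0 9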
def numpy_to_grayscale(tensor):
    # Y = 0.299 R + 0.587 G + 0.114 B
    return np.expand_dims(tensor[:, :, :, 0] * 0.299 + tensor[:, :, :, 1] * 0.587 + tensor[:, :, :, 2] * 0.114, axis=-1)
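# --- Sketch (not in the original gist): the weights above are the Rec. 601 luma
# coefficients; the conversion collapses an (N, 96, 96, 3) batch to (N, 96, 96, 1).
def _grayscale_example():
    rgb = np.random.rand(4, 96, 96, 3).astype(np.float32)
    assert numpy_to_grayscale(rgb).shape == (4, 96, 96, 1)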
def generator(X, batch_size, inputs_to_grayscale):
    while True:
        indices = np.arange(X.shape[0])
        np.random.shuffle(indices)
        for i in range(X.shape[0] // batch_size):
            current_indices = indices[i * batch_size:((i + 1) * batch_size)]
            X_batch = (X[current_indices] / 255.0).astype(np.float32)
            if inputs_to_grayscale:
                yield numpy_to_grayscale(X_batch), X_batch
            else:
                yield X_batch, X_batch
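# --- Sketch (not in the original gist): one batch drawn from the generator.
# Inputs and targets are scaled to [0, 1]; in grayscale mode the input has a
# single channel while the target keeps all three.
def _generator_example():
    X_dummy = np.random.randint(0, 256, size=(512, 96, 96, 3), dtype=np.uint8)
    x_batch, y_batch = next(generator(X_dummy, batch_size=256, inputs_to_grayscale=True))
    assert x_batch.shape == (256, 96, 96, 1)
    assert y_batch.shape == (256, 96, 96, 3)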
### As AutoEncoder
class SamplingCallback(Callback):
    def __init__(self, model, predict_images, output_dir):
        self.model = model
        # keep the sample count a multiple of 128 for the TPU batch size
        self.predict_images = (predict_images[:128] / 255.0).astype(np.float32)
        self.output_dir = output_dir
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        os.mkdir(output_dir)

    def combine_images(self, preds):
        # tile the predictions into a 10x10 grid and convert back to uint8
        ny, nx, nc = preds.shape[1:]
        combined = np.zeros((ny * 10, nx * 10, nc), dtype=np.float32)
        for i in range(preds.shape[0]):
            col_i = i % 10
            row_i = i // 10
            combined[row_i * ny:(row_i + 1) * ny, col_i * nx:(col_i + 1) * nx, :] = preds[i, :, :, :]
        combined = (combined * 255.0).astype(np.uint8)
        return combined

    def on_train_begin(self, logs):
        combined = self.combine_images(self.predict_images[:100])
        with Image.fromarray(combined) as img:
            img.save(f"{self.output_dir}/groundtruth.png")

    def on_epoch_end(self, epoch, logs):
        preds = self.model.predict(self.predict_images)[:100]
        combined = self.combine_images(preds)
        with Image.fromarray(combined) as img:
            img.save(f"{self.output_dir}/epoch_{epoch:03}.png")
def train_as_autoencoder(use_skip_connections, use_tpu=False):
    model = create_unet(use_skip_connections, grayscale_inputs=False)
    # train_test
    stl10 = STL10("./stl10")
    X_train, y_train = stl10.get_files("train")
    X_test, y_test = stl10.get_files("test")
    X_all = np.concatenate((X_train, X_test), axis=0)
    y_all = np.concatenate((y_train, y_test), axis=0)
    model.compile(tf.train.MomentumOptimizer(1e-3, 0.9), loss=loss_function)
    # to-tpu
    if use_tpu:
        tpu_grpc_url = "grpc://" + os.environ["COLAB_TPU_ADDR"]
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
        strategy = keras_support.TPUDistributionStrategy(tpu_cluster_resolver)
        model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)
    # callback
    cb = SamplingCallback(model, X_all[::23][:128], "./autoencoder")
    hist = History()
    model.fit_generator(generator(X_all, 256, inputs_to_grayscale=False),
                        steps_per_epoch=X_all.shape[0] // 256,
                        callbacks=[cb, hist], epochs=5)
    # save_result
    with open("./autoencoder/history.dat", "wb") as fp:
        pickle.dump(hist.history, fp)
    filename = f"autoencoder_skip_{use_skip_connections}.tar.gz"
    if os.path.exists(filename):
        os.remove(filename)
    with tarfile.open(filename, "w:gz") as tar:
        tar.add("autoencoder")
### As Coloring
class ColoringCallback(Callback):
    def __init__(self, coloring_model, output_dir):
        self.coloring_model = coloring_model
        self.output_dir = output_dir
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        os.mkdir(output_dir)
        cats = np.load("cats_images.npz")
        self.cats_color = (cats["color"] / 255.0).astype(np.float32)
        self.cats_gray = (cats["gray"] / 255.0).astype(np.float32)

    def combine_images(self, preds):
        ny, nx, nc = preds.shape[1:]
        combined = np.zeros((ny * 10, nx * 10, nc), dtype=np.float32)
        for i in range(preds.shape[0]):
            col_i = i % 10
            row_i = i // 10
            combined[row_i * ny:(row_i + 1) * ny, col_i * nx:(col_i + 1) * nx, :] = preds[i, :, :, :]
        combined = (combined * 255.0).astype(np.uint8)
        return combined

    def on_train_begin(self, logs):
        combined = self.combine_images(self.cats_color[:100])
        with Image.fromarray(combined) as img:
            img.save(f"{self.output_dir}/groundtruth.png")

    def on_epoch_end(self, epoch, logs):
        preds = self.coloring_model.predict(self.cats_gray)[:100]
        combined = self.combine_images(preds)
        with Image.fromarray(combined) as img:
            img.save(f"{self.output_dir}/coloring_epoch_{epoch:03}.png")
def train_as_coloring(use_skip_connections, use_tpu=False):
    model = create_unet(use_skip_connections, grayscale_inputs=True)
    model.summary()
    # train_test
    stl10 = STL10("./stl10")
    X_unlabeled, _ = stl10.get_files("unlabeled")
    model.compile(tf.train.MomentumOptimizer(1e-3, 0.9), loss=loss_function)
    # to-tpu
    if use_tpu:
        tpu_grpc_url = "grpc://" + os.environ["COLAB_TPU_ADDR"]
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
        strategy = keras_support.TPUDistributionStrategy(tpu_cluster_resolver)
        model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)
    # callback
    cb = ColoringCallback(model, "./coloring")
    hist = History()
    model.fit_generator(generator(X_unlabeled, 256, inputs_to_grayscale=True),
                        steps_per_epoch=X_unlabeled.shape[0] // 256,
                        callbacks=[cb, hist], epochs=5)
    # save_result
    with open("./coloring/history.dat", "wb") as fp:
        pickle.dump(hist.history, fp)
    filename = f"coloring_skip_{use_skip_connections}.tar.gz"
    if os.path.exists(filename):
        os.remove(filename)
    with tarfile.open(filename, "w:gz") as tar:
        tar.add("coloring")
if __name__ == "__main__":
    K.clear_session()
    # toggle use_skip_connections between False and True
    train_as_autoencoder(True, use_tpu=True)
    # train_as_coloring(False, use_tpu=True)