Skip to content

Instantly share code, notes, and snippets.

@koshian2
Created November 13, 2018 11:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save koshian2/1825e5d0f3e5b05dbb612f84ff394045 to your computer and use it in GitHub Desktop.
Save koshian2/1825e5d0f3e5b05dbb612f84ff394045 to your computer and use it in GitHub Desktop.
Latent features in U-Net
from keras.layers import Conv2D, BatchNormalization, Activation, Input, Concatenate, MaxPool2D, Conv2DTranspose, Add, Flatten, Dense, Reshape
from keras.models import Model
from keras.callbacks import ModelCheckpoint
from keras.optimizers import SGD
import keras.backend as K
from keras.objectives import mean_squared_error
from keras.datasets import mnist
import os, tarfile, shutil, pickle
import numpy as np
## UNet
def create_block(input, chs, latent_flag=False):
    """Apply two stacked Conv2D(3x3, "same") -> BatchNorm -> activation units.

    Args:
        input: Keras tensor fed into the first convolution.
        chs: number of filters used by both convolutions.
        latent_flag: when True, the second unit's activation is a ``tanh``
            layer named ``"latent_features"`` (so it can be looked up later
            via ``model.get_layer``); otherwise every activation is ReLU.

    Returns:
        The Keras tensor produced by the second unit.
    """
    h = input
    for unit_idx in range(2):
        # The original U-Net uses unpadded convolutions; "same" padding is used
        # here so feature-map sizes line up without cropping.
        h = Conv2D(chs, 3, padding="same")(h)
        h = BatchNormalization()(h)
        is_latent_unit = latent_flag and unit_idx == 1
        if is_latent_unit:
            act = Activation("tanh", name="latent_features")
        else:
            act = Activation("relu")
        h = act(h)
    return h
def create_unet(use_skip_connections, bottle_neck=False):
    """Build a small U-Net style autoencoder for 28x28x1 images.

    The encoder applies two conv blocks (4 and 8 filters), each followed by
    2x2 max pooling (28 -> 14 -> 7). The middle is one of two variants:
      * ``bottle_neck=False``: a 16-filter conv block whose final activation
        is the ``tanh`` layer named ``"latent_features"``;
      * ``bottle_neck=True``: a 16-filter conv block, then Flatten, then
        Dense(16, tanh, name="latent_features"), then Dense(784, relu)
        reshaped back to (7, 7, 16).
    The decoder mirrors the encoder with stride-2 transposed convolutions.
    When ``use_skip_connections`` is True, it concatenates the matching
    encoder block output before each decoder block. A 1x1 convolution with a
    sigmoid produces the single-channel reconstruction.

    Returns:
        An uncompiled ``keras.models.Model``.
    """
    inputs = Input((28, 28, 1))

    # Encoder. Block outputs are kept for the optional skip connections.
    # Layer construction order matches the original exactly (block, pool,
    # block, pool), which keeps auto-generated names and weight order stable.
    encoder_outputs = []
    h = inputs
    for enc_chs in (4, 8):
        h = create_block(h, enc_chs)
        encoder_outputs.append(h)
        h = MaxPool2D(2)(h)

    # Middle
    if bottle_neck:
        h = create_block(h, 16)
        h = Flatten()(h)
        h = Dense(16, activation="tanh", name="latent_features")(h)
        h = Dense(784, activation="relu")(h)  # 784 == 7 * 7 * 16
        h = Reshape((7, 7, 16))(h)
    else:
        h = create_block(h, 16, True)

    # Decoder: upsample 7 -> 14 -> 28, pairing with encoder outputs deepest-first.
    for dec_chs, skip in zip((8, 4), reversed(encoder_outputs)):
        h = Conv2DTranspose(dec_chs, kernel_size=2, strides=2)(h)
        if use_skip_connections:
            h = Concatenate()([skip, h])
        h = create_block(h, dec_chs)

    # Output head
    h = Conv2D(1, 1)(h)
    h = Activation("sigmoid")(h)
    return Model(inputs, h)
def loss_function(y_true, y_pred):
    """Per-sample reconstruction loss.

    Keras' ``mean_squared_error`` reduces over the last axis. The result is
    then summed over axes 1 and 2, presumably the spatial (H, W) axes of the
    image batch, giving one loss value per sample.
    """
    spatial_axes = (1, 2)
    per_pixel_error = mean_squared_error(y_true, y_pred)
    return K.sum(per_pixel_error, axis=spatial_axes)
def train_autoencoder(use_skip_connections, bottle_neck, *, epochs=1, batch_size=512):
    """Train the U-Net autoencoder on MNIST and archive the results.

    Builds the model with ``create_unet`` and compiles it with SGD
    (lr=1e-3, momentum=0.9) and ``loss_function``. It is then fitted to
    reconstruct the MNIST training images, scaled to [0, 1]. Weights are
    written by a ``ModelCheckpoint`` callback. The Keras training history is
    pickled, and the whole output directory is packed into
    ``<dest_dir>.tar.gz``.

    Args:
        use_skip_connections: forwarded to ``create_unet``.
        bottle_neck: forwarded to ``create_unet``.
        epochs: number of training epochs (default 1, the value previously
            hard-coded).
        batch_size: mini-batch size passed to ``model.fit`` (default 512).
    """
    model = create_unet(use_skip_connections, bottle_neck)
    model.compile(SGD(1e-3, 0.9), loss=loss_function)
    model.summary()

    # Only the training images are needed; labels and the test split are discarded.
    (X, _), (_, _) = mnist.load_data()
    X = (X / 255.0).reshape(-1, 28, 28, 1)

    dest_dir = f"mnist_skip_{use_skip_connections}_bottleneck_{bottle_neck}"
    # makedirs(exist_ok=True) replaces the racy exists()/mkdir() check-then-act pair.
    os.makedirs(dest_dir, exist_ok=True)
    weights_path = os.path.join(
        dest_dir, f"{use_skip_connections}_{bottle_neck}_weights.hdf5"
    )
    hist_path = os.path.join(
        dest_dir,
        f"history_skip_{use_skip_connections}_bottleneck_{bottle_neck}.dat",
    )
    cp = ModelCheckpoint(
        weights_path, monitor="loss", verbose=1, save_weights_only=True
    )

    # Autoencoder: the input images are also the reconstruction targets.
    history = model.fit(
        X, X, batch_size=batch_size, epochs=epochs, callbacks=[cp]
    ).history
    with open(hist_path, "wb") as fp:
        pickle.dump(history, fp)

    # Bundle weights + history so they can be downloaded as a single file.
    with tarfile.open(dest_dir + ".tar.gz", mode="w:gz") as tar:
        tar.add(dest_dir)
def get_latent_features(use_skip_connections, bottle_neck):
    """Extract and save latent codes of the MNIST training set.

    Rebuilds the network with ``create_unet``. It loads the weights from
    ``mnist_skip_<skip>_bottleneck_<bn>/<skip>_<bn>_weights.hdf5`` (the path
    ``train_autoencoder`` writes for the same flags). It then runs the
    training images through a sub-model that ends at the ``"latent_features"``
    layer. The resulting codes and the digit labels are saved with
    ``np.savez_compressed`` under the name
    ``latent_skip_<skip>_bottleneck_<bn>``.
    """
    model = create_unet(use_skip_connections, bottle_neck)

    run_tag = f"skip_{use_skip_connections}_bottleneck_{bottle_neck}"
    weights_path = f"mnist_{run_tag}/{use_skip_connections}_{bottle_neck}_weights.hdf5"
    outname = f"latent_{run_tag}"
    model.load_weights(weights_path)

    (images, labels), (_, _) = mnist.load_data()
    images = (images / 255.0).reshape(-1, 28, 28, 1)

    # Sub-model that stops at the latent layer.
    features = model.get_layer("latent_features").output
    if not bottle_neck:
        # The conv latent map is flattened into a vector. Global average
        # pooling was clearly worse here (cluster purity dropped by roughly
        # 30%), so Flatten is used instead.
        features = Flatten()(features)
    latent_model = Model(model.input, features)
    latent_model.summary()

    codes = latent_model.predict(images, batch_size=512)
    np.savez_compressed(outname, latent=codes, ground_truth=labels)
if __name__ == "__main__":
    # Train the variant without skip connections and without the dense
    # bottleneck. Once its weights exist, get_latent_features(False, False)
    # can be called with the same flags to dump the latent codes.
    skip_connections, dense_bottleneck = False, False
    train_autoencoder(skip_connections, dense_bottleneck)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment