Created
November 13, 2018 11:25
-
-
Save koshian2/1825e5d0f3e5b05dbb612f84ff394045 to your computer and use it in GitHub Desktop.
Latent features in U-Net
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.layers import Conv2D, BatchNormalization, Activation, Input, Concatenate, MaxPool2D, Conv2DTranspose, Add, Flatten, Dense, Reshape | |
from keras.models import Model | |
from keras.callbacks import ModelCheckpoint | |
from keras.optimizers import SGD | |
import keras.backend as K | |
from keras.objectives import mean_squared_error | |
from keras.datasets import mnist | |
import os, tarfile, shutil, pickle | |
import numpy as np | |
## UNet | |
def create_block(input, chs, latent_flag=False):
    """Build one U-Net stage: two (3x3 Conv -> BatchNorm -> activation) units.

    Args:
        input: Keras tensor fed into the block.
        chs: number of filters used by both convolutions.
        latent_flag: when True, the second unit ends in a tanh Activation
            layer named "latent_features" (so it can later be retrieved with
            ``model.get_layer("latent_features")``); otherwise every unit
            ends in ReLU.

    Returns:
        The output tensor of the second unit.
    """
    out = input
    for unit in (0, 1):
        # The original U-Net uses unpadded convolutions; "same" padding is used
        # here so feature-map sizes do not need to be cropped/aligned by hand.
        out = Conv2D(chs, 3, padding="same")(out)
        out = BatchNormalization()(out)
        is_latent = latent_flag and unit == 1
        activation = Activation("tanh", name="latent_features") if is_latent else Activation("relu")
        out = activation(out)
    return out
def create_unet(use_skip_connections, bottle_neck=False, *, latent_dim=16):
    """Build a small U-Net style autoencoder for 28x28x1 (MNIST-sized) images.

    Args:
        use_skip_connections: when True, encoder feature maps are concatenated
            onto the decoder at matching resolutions (classic U-Net); when
            False the network is a plain convolutional autoencoder.
        bottle_neck: when False, the latent representation is the tanh output
            of the 7x7x16 middle block (layer "latent_features"). When True,
            the middle block output is flattened, squeezed through a Dense
            layer of ``latent_dim`` units named "latent_features", then
            expanded back to 7x7x16 for the decoder.
        latent_dim: width of the Dense bottleneck; only used when
            ``bottle_neck`` is True. Defaults to 16 (the previous fixed value).

    Returns:
        An uncompiled ``keras.models.Model`` mapping (28, 28, 1) inputs to
        (28, 28, 1) sigmoid outputs.
    """
    # Named `inputs` (not `input`) to avoid shadowing the builtin.
    inputs = Input((28, 28, 1))

    # Encoder: 28x28x4 -> pool -> 14x14x8 -> pool -> 7x7x8
    block1 = create_block(inputs, 4)
    x = MaxPool2D(2)(block1)
    block2 = create_block(x, 8)
    x = MaxPool2D(2)(block2)

    # Middle: 7x7x16 feature map, optionally squeezed through a Dense bottleneck.
    mid_h, mid_w, mid_ch = 7, 7, 16
    if bottle_neck:
        x = create_block(x, mid_ch)
        x = Flatten()(x)
        x = Dense(latent_dim, activation="tanh", name="latent_features")(x)
        # Expand to exactly as many units as the reshape target holds
        # (7 * 7 * 16 = 784) rather than hard-coding that product.
        x = Dense(mid_h * mid_w * mid_ch, activation="relu")(x)
        x = Reshape((mid_h, mid_w, mid_ch))(x)
    else:
        x = create_block(x, mid_ch, True)

    # Decoder: stride-2 transposed convolutions upsample 7 -> 14 -> 28.
    x = Conv2DTranspose(8, kernel_size=2, strides=2)(x)
    if use_skip_connections:
        x = Concatenate()([block2, x])
    x = create_block(x, 8)
    x = Conv2DTranspose(4, kernel_size=2, strides=2)(x)
    if use_skip_connections:
        x = Concatenate()([block1, x])
    x = create_block(x, 4)

    # Output: 1x1 conv down to one channel; sigmoid because the training
    # images are scaled to [0, 1].
    x = Conv2D(1, 1)(x)
    x = Activation("sigmoid")(x)
    return Model(inputs, x)
def loss_function(y_true, y_pred):
    """Per-sample reconstruction loss: squared error summed over the image.

    Keras' ``mean_squared_error`` averages only over the last (channel) axis,
    leaving one value per pixel; summing over axes 1 and 2 then totals those
    per-pixel errors for each sample.
    Assumes tensors shaped (batch, height, width, channels) — matches the
    28x28x1 model here.
    """
    per_pixel = mean_squared_error(y_true, y_pred)
    return K.sum(per_pixel, axis=(1, 2))
def train_autoencoder(use_skip_connections, bottle_neck, *, epochs=1, batch_size=512):
    """Train the U-Net autoencoder on MNIST and archive the results.

    Artifacts are written into the directory
    ``mnist_skip_{use_skip_connections}_bottleneck_{bottle_neck}``:
      * ``{use_skip_connections}_{bottle_neck}_weights.hdf5`` - weights saved
        by a ``ModelCheckpoint`` callback monitoring the training loss;
      * ``history_skip_..._bottleneck_....dat`` - the pickled Keras history dict.
    The directory is then packed into ``<directory>.tar.gz``.

    Args:
        use_skip_connections: forwarded to ``create_unet``.
        bottle_neck: forwarded to ``create_unet``.
        epochs: number of training epochs (default 1, the previous fixed value).
        batch_size: minibatch size (default 512, the previous fixed value).

    Returns:
        The ``History.history`` dict produced by ``model.fit``.
    """
    model = create_unet(use_skip_connections, bottle_neck)
    model.compile(SGD(1e-3, 0.9), loss=loss_function)  # learning rate 1e-3, momentum 0.9
    model.summary()

    # Autoencoder: only the training images are needed; they are both input and
    # target. Scale to [0, 1] to match the model's sigmoid output.
    (X, _), (_, _) = mnist.load_data()
    X = (X / 255.0).reshape(-1, 28, 28, 1)

    dest_dir = f"mnist_skip_{use_skip_connections}_bottleneck_{bottle_neck}"
    # Create the output directory up front: race-free (no exists/mkdir gap) and
    # guaranteed to exist before the checkpoint callback writes into it.
    os.makedirs(dest_dir, exist_ok=True)
    weights_path = os.path.join(dest_dir, f"{use_skip_connections}_{bottle_neck}_weights.hdf5")
    hist_path = os.path.join(dest_dir, f"history_skip_{use_skip_connections}_bottleneck_{bottle_neck}.dat")

    cp = ModelCheckpoint(weights_path, monitor="loss", verbose=1, save_weights_only=True)
    history = model.fit(X, X, batch_size=batch_size, epochs=epochs, callbacks=[cp]).history

    with open(hist_path, "wb") as fp:
        pickle.dump(history, fp)
    with tarfile.open(dest_dir + ".tar.gz", mode="w:gz") as tar:
        tar.add(dest_dir)
    return history
def get_latent_features(use_skip_connections, bottle_neck):
    """Extract latent features for the MNIST training images from a trained model.

    Rebuilds the network with ``create_unet``, loads the weights that
    ``train_autoencoder`` wrote for the same configuration, and runs the
    training images through a sub-model ending at the layer named
    "latent_features". Results are saved with ``np.savez_compressed`` to
    ``latent_skip_{use_skip_connections}_bottleneck_{bottle_neck}.npz``,
    holding ``latent`` (one feature row per image) and ``ground_truth``
    (the MNIST labels).
    """
    unet = create_unet(use_skip_connections, bottle_neck)
    weights_dir = f"mnist_skip_{use_skip_connections}_bottleneck_{bottle_neck}"
    weights_file = weights_dir + f"/{use_skip_connections}_{bottle_neck}_weights.hdf5"
    out_stem = f"latent_skip_{use_skip_connections}_bottleneck_{bottle_neck}"
    unet.load_weights(weights_file)

    (images, labels), (_, _) = mnist.load_data()
    images = (images / 255.0).reshape(-1, 28, 28, 1)

    latent = unet.get_layer("latent_features").output
    # Without the Dense bottleneck the latent layer is the 7x7x16 middle feature
    # map, so flatten it to one vector per image. Global average pooling was
    # clearly worse here (purity dropped by roughly 30%), hence Flatten.
    latent = latent if bottle_neck else Flatten()(latent)

    encoder = Model(unet.input, latent)
    encoder.summary()
    features = encoder.predict(images, batch_size=512)
    np.savez_compressed(out_stem, latent=features, ground_truth=labels)
if __name__ == "__main__":
    # Train the baseline configuration (no skip connections, no bottleneck).
    train_autoencoder(use_skip_connections=False, bottle_neck=False)
    # Once training has finished, uncomment to dump latent features for it:
    # get_latent_features(use_skip_connections=False, bottle_neck=False)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment