Created
November 23, 2018 12:00
-
-
Save koshian2/fb9277cd2733660a224b8e14512cce96 to your computer and use it in GitHub Desktop.
Train five CNNs at the same time as a single joint model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, AveragePooling2D, GlobalAvgPool2D, Dense, Concatenate | |
from tensorflow.keras.models import Model | |
import tensorflow.keras.backend as K | |
from tensorflow.keras.callbacks import History, Callback | |
from tensorflow.contrib.tpu.python.tpu import keras_support | |
from keras.objectives import categorical_crossentropy | |
import numpy as np | |
from keras.datasets import cifar10 | |
from keras.utils import to_categorical | |
import os, pickle | |
def basic_conv_block(input, chs, reps):
    """Chain ``reps`` units of Conv2D(3x3, "same") -> BatchNormalization -> ReLU.

    Args:
        input: Tensor the first unit is applied to.
        chs: Number of filters used by every Conv2D in the block.
        reps: Number of conv/BN/ReLU units to stack.

    Returns:
        The output of the last unit, or ``input`` unchanged when ``reps`` is 0.
    """
    tensor = input
    for _ in range(reps):
        # Layers are built in Conv -> BN -> ReLU order and applied in that order.
        for layer in (Conv2D(chs, 3, padding="same"), BatchNormalization(), Activation("relu")):
            tensor = layer(tensor)
    return tensor
def create_cnn_output(input):
    """Build one CNN branch and return its pooled feature vector.

    The branch has three stages of three conv/BN/ReLU units each, with 64,
    128 and 256 filters. A 2x2 average pooling sits between consecutive
    stages, and a global average pooling follows the last stage.

    Args:
        input: Image tensor the branch is applied to.

    Returns:
        The globally average-pooled output of the final (256-filter) stage.
    """
    features = input
    for stage_index, channels in enumerate((64, 128, 256)):
        # Downsample before every stage except the first.
        if stage_index:
            features = AveragePooling2D(2)(features)
        features = basic_conv_block(features, channels, 3)
    return GlobalAvgPool2D()(features)
def create_joint_model(n_models=5, num_classes=10):
    """Build one Keras model containing ``n_models`` parallel CNN branches.

    Every branch (``create_cnn_output``) reads the same 32x32x3 input. The
    pooled features of all branches are concatenated and fed to a single
    softmax ``Dense`` layer, so all branches are trained together.

    Args:
        n_models: Number of parallel CNN branches. Must be >= 1. The default
            of 5 matches the original hard-coded value.
        num_classes: Number of softmax output units. Defaults to 10.

    Returns:
        An uncompiled ``Model`` mapping the image input to class probabilities.

    Raises:
        ValueError: If ``n_models`` is less than 1.
    """
    if n_models < 1:
        raise ValueError(f"n_models must be >= 1, got {n_models}")
    inputs = Input((32, 32, 3))
    branch_outputs = [create_cnn_output(inputs) for _ in range(n_models)]
    # A single Concatenate over the list keeps the branch order of the original
    # chain of pairwise concatenations. Concatenate needs >= 2 inputs, so a
    # lone branch is used directly.
    if n_models == 1:
        features = branch_outputs[0]
    else:
        features = Concatenate()(branch_outputs)
    probs = Dense(num_classes, activation="softmax")(features)
    return Model(inputs, probs)
class Checkpoint(Callback):
    """Keras callback that saves weights whenever validation accuracy improves.

    Args:
        model: Object whose ``save_weights`` is called. It is stored as
            ``self.model``. Note that Keras' ``fit`` may also assign
            ``self.model`` through ``set_model``.
        filepath: Destination passed to ``save_weights`` (HDF5 format).
    """

    def __init__(self, model, filepath):
        # Initialise the base Callback state before overriding attributes.
        super().__init__()
        self.model = model
        self.filepath = filepath
        # Start below any achievable accuracy so the first validated epoch
        # always writes a file. Otherwise a later load_weights could fail
        # because nothing was ever saved.
        self.best_val_acc = float("-inf")

    def on_epoch_end(self, epoch, logs=None):
        """Save weights if this epoch's validation accuracy is a new best.

        The accuracy is read from ``logs["val_acc"]``, falling back to
        ``logs["val_accuracy"]``. If neither key is present (for example,
        ``fit`` ran without validation data), the epoch is skipped.
        """
        logs = logs or {}
        val_acc = logs.get("val_acc", logs.get("val_accuracy"))
        if val_acc is None:
            return
        if val_acc > self.best_val_acc:
            self.model.save_weights(self.filepath, save_format="h5")
            self.best_val_acc = val_acc
            print("Weights saved.", self.best_val_acc)
def train():
    """Train the joint multi-branch CNN on CIFAR-10 using a Colab TPU.

    Steps:
        1. Load CIFAR-10, scale pixels to [0, 1] as float32, and one-hot
           encode the labels.
        2. Build and compile ``create_joint_model()`` with Adam and
           categorical cross-entropy.
        3. Convert it with ``keras_to_tpu_model`` for the TPU at
           ``$COLAB_TPU_ADDR``.
        4. Fit for 100 epochs. ``Checkpoint`` keeps the best-validation
           weights in ``single_weights.hdf5``.
        5. Reload those weights, evaluate on the test set, and pickle the
           training history (plus a ``"test_result"`` entry) to
           ``single_fifth.dat``.

    Raises:
        RuntimeError: If ``COLAB_TPU_ADDR`` is not set. The check runs before
            any data is downloaded or any model is built.
    """
    weights_path = "single_weights.hdf5"
    history_path = "single_fifth.dat"
    batch_size = 1024

    # Fail fast: without a TPU address the rest of the run is wasted work.
    tpu_addr = os.environ.get("COLAB_TPU_ADDR")
    if not tpu_addr:
        raise RuntimeError(
            "COLAB_TPU_ADDR is not set; this script expects a Colab TPU runtime."
        )

    (X_train, y_train), (X_test, y_test) = cifar10.load_data()
    # Cast before scaling. Dividing the raw integer arrays by 255.0 would
    # produce float64 and double the memory footprint.
    X_train = X_train.astype("float32") / 255.0
    X_test = X_test.astype("float32") / 255.0
    y_train, y_test = to_categorical(y_train), to_categorical(y_test)

    model = create_joint_model()
    model.compile(tf.train.AdamOptimizer(), loss="categorical_crossentropy", metrics=["acc"])

    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver("grpc://" + tpu_addr)
    strategy = keras_support.TPUDistributionStrategy(tpu_cluster_resolver)
    model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)
    model.summary()

    hist = History()
    cp = Checkpoint(model, weights_path)
    model.fit(X_train, y_train, batch_size=batch_size, callbacks=[hist, cp],
              validation_data=(X_test, y_test), epochs=100)

    # Load the best weights saved by Checkpoint before the final evaluation.
    model.load_weights(weights_path)
    # Evaluate with the same batch size used for validation during fit.
    test_result = model.evaluate(X_test, y_test, batch_size=batch_size)
    print("Test", test_result)
    hist.history["test_result"] = test_result
    with open(history_path, "wb") as fp:
        pickle.dump(hist.history, fp)
def _main():
    """Reset the Keras backend session, then run the training pipeline."""
    K.clear_session()
    train()


if __name__ == "__main__":
    _main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment