@koshian2
Created November 23, 2018 12:00
Train 5 networks at the same time
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, AveragePooling2D, GlobalAvgPool2D, Dense, Concatenate
from tensorflow.keras.models import Model
import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import History, Callback
from tensorflow.contrib.tpu.python.tpu import keras_support
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
import numpy as np
import os, pickle
def basic_conv_block(input, chs, reps):
    x = input
    for i in range(reps):
        x = Conv2D(chs, 3, padding="same")(x)
        x = BatchNormalization()(x)
        x = Activation("relu")(x)
    return x
def create_cnn_output(input):
    x = basic_conv_block(input, 64, 3)
    x = AveragePooling2D(2)(x)
    x = basic_conv_block(x, 128, 3)
    x = AveragePooling2D(2)(x)
    x = basic_conv_block(x, 256, 3)
    x = GlobalAvgPool2D()(x)
    return x
def create_joint_model():
    input = Input((32,32,3))
    # Five CNN branches share the same input; their global-pooled features are concatenated
    outputs = [create_cnn_output(input) for i in range(5)]
    x = Concatenate()([outputs[0], outputs[1]])
    x = Concatenate()([x, outputs[2]])
    x = Concatenate()([x, outputs[3]])
    x = Concatenate()([x, outputs[4]])
    x = Dense(10, activation="softmax")(x)
    return Model(input, x)
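
# Not part of the original gist: a minimal sketch of the single-network
# baseline that the commented-out create_one_model() call in train() below
# presumably refers to (one CNN branch instead of five concatenated ones).
def create_one_model():
    input = Input((32,32,3))
    x = create_cnn_output(input)
    x = Dense(10, activation="softmax")(x)
    return Model(input, x)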
class Checkpoint(Callback):
    def __init__(self, model, filepath):
        super().__init__()
        self.model = model
        self.filepath = filepath
        self.best_val_acc = 0.0

    def on_epoch_end(self, epoch, logs):
        # Save weights whenever the validation accuracy improves
        if self.best_val_acc < logs["val_acc"]:
            self.model.save_weights(self.filepath, save_format="h5")
            self.best_val_acc = logs["val_acc"]
            print("Weights saved.", self.best_val_acc)
def train():
    (X_train, y_train), (X_test, y_test) = cifar10.load_data()
    X_train, X_test = X_train / 255.0, X_test / 255.0
    y_test_label = y_test.copy()
    y_train, y_test = to_categorical(y_train), to_categorical(y_test)

    model = create_joint_model()
    #model = create_one_model()
    batch_size = 1024
    model.compile(tf.train.AdamOptimizer(), loss="categorical_crossentropy", metrics=["acc"])

    # Convert the Keras model into a TPU model (TF 1.x Colab TPU runtime)
    tpu_grpc_url = "grpc://" + os.environ["COLAB_TPU_ADDR"]
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
    strategy = keras_support.TPUDistributionStrategy(tpu_cluster_resolver)
    model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)
    model.summary()

    hist = History()
    cp = Checkpoint(model, "single_weights.hdf5")
    model.fit(X_train, y_train, batch_size=batch_size, callbacks=[hist, cp],
              validation_data=(X_test, y_test), epochs=100)

    # Load the best weights
    model.load_weights("single_weights.hdf5")
    test_result = model.evaluate(X_test, y_test)
    print("Test", test_result)
    hist.history["test_result"] = test_result
    with open("single_fifth.dat", "wb") as fp:
        pickle.dump(hist.history, fp)
if __name__ == "__main__":
    K.clear_session()
    train()
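
Note: this gist targets the TF 1.x Colab TPU runtime (tf.contrib and keras_to_tpu_model were removed in TensorFlow 2.x), so it needs a Colab TPU instance where COLAB_TPU_ADDR is set. Not part of the original gist: once a run finishes, the pickled history can be inspected with a short sketch like the one below, assuming single_fifth.dat is in the working directory.

import pickle

with open("single_fifth.dat", "rb") as fp:
    history = pickle.load(fp)

# Best validation accuracy over the 100 epochs and the held-out test result
print("best val_acc:", max(history["val_acc"]))
print("test [loss, acc]:", history["test_result"])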