Last active
November 8, 2018 02:59
-
-
Save koshian2/8cb98ddac404be6c69b71b3217bd24c9 to your computer and use it in GitHub Desktop.
Natural-or-artificial category parameters for CIFAR-10 classification
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, Input, GlobalAveragePooling2D, AveragePooling2D, Dense | |
from tensorflow.keras.models import Model | |
from tensorflow.keras.callbacks import History | |
import tensorflow.keras.backend as K | |
from tensorflow.contrib.tpu.python.tpu import keras_support | |
from keras.objectives import categorical_crossentropy | |
from keras.metrics import categorical_accuracy | |
from keras.datasets import cifar10 | |
from keras.utils import to_categorical | |
import numpy as np | |
import os, pickle, tarfile | |
def create_basic_block(input, filter, reps):
    """Stack ``reps`` Conv2D(3x3, same padding) -> BatchNorm -> ReLU units.

    Args:
        input: Keras tensor the block is built on.
        filter: number of convolution filters used by every unit.
        reps: number of Conv/BN/ReLU units to stack.

    Returns:
        The output tensor of the last unit (``input`` unchanged if ``reps`` is 0).
    """
    out = input
    for _ in range(reps):
        conv = Conv2D(filter, 3, padding="same")(out)
        normed = BatchNormalization()(conv)
        out = Activation("relu")(normed)
    return out
def create_model():
    """Build the 10-class CIFAR-10 classifier.

    Architecture: three stages of ``create_basic_block`` (64, 128 and 256
    filters, 3 units each) on a 32x32x3 input, separated by 2x2 average
    pooling. These are followed by global average pooling and a 10-way softmax.

    Returns:
        An uncompiled Keras ``Model``.
    """
    inputs = Input((32, 32, 3))
    features = inputs
    for stage, width in enumerate((64, 128, 256)):
        # Downsample between stages, but not before the first one.
        if stage > 0:
            features = AveragePooling2D(2)(features)
        features = create_basic_block(features, width, 3)
    pooled = GlobalAveragePooling2D()(features)
    probs = Dense(10, activation="softmax")(pooled)
    return Model(inputs, probs)
from itertools import combinations


def loss_function_category_soft_multi(y_true, y_pred):
    """Class cross-entropy plus cross-entropy on contiguous groups of classes.

    For every pair ``start < end`` drawn from ``range(10)`` (45 pairs), the
    probability mass in the class range ``[start, end]`` is compared between
    labels and predictions. The comparison is a two-way distribution:
    (mass inside the range, mass outside it).

    Args:
        y_true: one-hot label tensor. Only the first 10 columns (the CIFAR-10
            classes) are used; any further columns, such as natural/artificial
            flags, are ignored.
        y_pred: predicted class probabilities; presumably the 10-column softmax
            output of ``create_model`` (verify if used with another model).

    Returns:
        Per-sample loss tensor.
    """
    # Drop any extra label columns so they cannot break the 10-class terms
    # (this mirrors the slicing done in acc_category).
    y_true = y_true[:, :10]
    # Plain per-class cross-entropy first.
    loss = categorical_crossentropy(y_true, y_pred)
    # Pick a start and an end class with 10C2.
    for start, end in combinations(range(10), 2):
        mass_pred = K.sum(y_pred[:, start:end + 1], axis=-1, keepdims=True)
        mass_true = K.sum(y_true[:, start:end + 1], axis=-1, keepdims=True)
        # categorical_crossentropy renormalises y_pred to sum to 1 along the
        # last axis. A single-column input would therefore always become 1.0
        # and the term would be ~0. Build the two-way split instead. Values
        # near 0 are clipped to [epsilon, 1 - epsilon] inside the loss.
        probs_pred = K.concatenate([mass_pred, 1.0 - mass_pred], axis=-1)
        probs_true = K.concatenate([mass_true, 1.0 - mass_true], axis=-1)
        loss += categorical_crossentropy(probs_true, probs_pred)
    return loss
def acc_category(y_true, y_pred):
    """Categorical accuracy measured on the first 10 (class) label columns.

    Any additional label columns, e.g. the natural/artificial flags produced
    by ``generator(..., return_category=True)``, are sliced off before
    comparing with ``y_pred``.
    """
    class_labels = y_true[:, :10]
    return categorical_accuracy(class_labels, y_pred)
def generator(X, y, batch_size, shuffle, return_category):
    """Yield ``(X_batch, y_batch)`` pairs forever, for ``fit_generator``.

    Each pass over the data yields ``len(X) // batch_size`` full batches. The
    trailing partial batch is dropped, so every batch has exactly
    ``batch_size`` rows.

    Args:
        X: image array indexed along axis 0. Assumes uint8 pixels in
            [0, 255] (e.g. from ``cifar10.load_data()``); batches are divided
            by 255 and cast to float32.
        y: integer class labels, shape ``(N, 1)`` or ``(N,)``.
        batch_size: samples per batch; must satisfy ``1 <= batch_size <= len(X)``.
        shuffle: if true, reshuffle the sample order at the start of each pass.
        return_category: if true, add two label columns. Column 10 is 1.0 for
            classes 2-7 (natural objects); column 11 is its complement
            (artificial objects).

    Yields:
        ``X_batch`` (float32) and a float32 one-hot ``y_batch`` with 10 columns,
        or 12 when ``return_category`` is true.

    Raises:
        ValueError: on the first ``next()``, if ``X`` and ``y`` differ in length
            or ``batch_size`` is out of range. Without this check, a dataset
            smaller than one batch would spin in ``while True`` and never yield.
    """
    n_samples = X.shape[0]
    if len(y) != n_samples:
        raise ValueError(f"X has {n_samples} samples but y has {len(y)}")
    if not 1 <= batch_size <= n_samples:
        raise ValueError(
            f"batch_size must be between 1 and {n_samples}, got {batch_size}")

    # Loop-invariant: label width and the one-hot lookup table. Fancy-indexing
    # the table returns a copy, so per-batch edits never touch it.
    ncol = 12 if return_category else 10
    one_hot = np.eye(ncol, dtype=np.float32)
    n_batches = n_samples // batch_size

    while True:
        indices = np.arange(n_samples)
        if shuffle:
            np.random.shuffle(indices)
        for i in range(n_batches):
            current_indices = indices[i * batch_size:(i + 1) * batch_size]
            X_batch = (X[current_indices] / 255.0).astype(np.float32)
            labels = np.asarray(y[current_indices], dtype=np.int64).ravel()
            y_batch = one_hot[labels]
            if return_category:
                # Natural object? (classes 2-7)
                natural = (labels >= 2) & (labels <= 7)
                y_batch[:, 10] = natural.astype(np.float32)
                # Artificial object: the complement.
                y_batch[:, 11] = 1.0 - y_batch[:, 10]
            yield X_batch, y_batch
def train(consider_category, use_tpu, rep):
    """Train one fresh model on CIFAR-10 and pickle its training history.

    Args:
        consider_category: if true, use ``loss_function_category_soft_multi``;
            otherwise use plain categorical cross-entropy.
        use_tpu: if true, convert the compiled model to a TPU model. The TPU
            address is read from the ``COLAB_TPU_ADDR`` environment variable.
        rep: trial index, used in the progress message and output file name.

    Returns:
        The output directory name (``"consider_category_multi"`` or
        ``"normal"``). The history is written to
        ``{dir}/history_{dir}_{rep}.dat``.
    """
    model = create_model()
    loss = (loss_function_category_soft_multi if consider_category
            else "categorical_crossentropy")
    model.compile(tf.train.AdamOptimizer(), loss=loss, metrics=[acc_category])
    if use_tpu:
        tpu_grpc_url = "grpc://" + os.environ["COLAB_TPU_ADDR"]
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
        strategy = keras_support.TPUDistributionStrategy(tpu_cluster_resolver)
        model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)

    (X_train, y_train), (X_test, y_test) = cifar10.load_data()
    batch_size = 512
    print("Trial", rep, "starts")

    hist = History()
    model.fit_generator(
        generator(X_train, y_train, batch_size, True, False),
        # Derive step counts from the data instead of hard-coding 50000/10000.
        steps_per_epoch=X_train.shape[0] // batch_size,
        callbacks=[hist],
        # Validation batches are not shuffled. Every epoch and every trial is
        # then scored on the same fixed subset, so histories are comparable.
        validation_data=generator(X_test, y_test, batch_size, False, False),
        validation_steps=X_test.shape[0] // batch_size,
        epochs=100,
    )

    flag = "consider_category_multi" if consider_category else "normal"
    # makedirs(exist_ok=True) avoids the check-then-create race of exists()+mkdir().
    os.makedirs(flag, exist_ok=True)
    with open(f"{flag}/history_{flag}_{rep}.dat", "wb") as fp:
        pickle.dump(hist.history, fp)
    return flag
def train_all(consider_category, use_tpu, n_trials=5):
    """Run ``n_trials`` independent training trials and archive their histories.

    After each trial, the output directory returned by ``train`` is packed
    into ``{dir}.tar.gz``, overwriting the previous archive. The archive
    therefore already holds every finished trial if a later one is
    interrupted.

    Args:
        consider_category: forwarded to ``train``.
        use_tpu: forwarded to ``train``.
        n_trials: number of trials to run (default 5).

    Raises:
        ValueError: if ``n_trials`` < 1, because there would be nothing to
            archive.
    """
    if n_trials < 1:
        raise ValueError(f"n_trials must be >= 1, got {n_trials}")
    for trial in range(n_trials):
        dir_name = train(consider_category, use_tpu, trial)
        with tarfile.open(f"{dir_name}.tar.gz", "w:gz") as tar:
            tar.add(dir_name)
if __name__ == "__main__":
    # Reset Keras' global session state first (presumably so that re-running
    # in a notebook starts clean). Then run every trial with the
    # category-aware loss on the TPU.
    K.clear_session()
    train_all(consider_category=True, use_tpu=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment