@koshian2
Last active November 8, 2018 02:59
Adding natural-vs-artificial category information to CIFAR-10 classification
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, Input, GlobalAveragePooling2D, AveragePooling2D, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import History
import tensorflow.keras.backend as K
from tensorflow.contrib.tpu.python.tpu import keras_support
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.metrics import categorical_accuracy
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
import numpy as np
import os, pickle, tarfile
def create_basic_block(input, filter, reps):
    # Stack of `reps` Conv-BatchNorm-ReLU blocks with the same number of filters
    x = input
    for i in range(reps):
        x = Conv2D(filter, 3, padding="same")(x)
        x = BatchNormalization()(x)
        x = Activation("relu")(x)
    return x
def create_model():
    input = Input((32, 32, 3))
    x = create_basic_block(input, 64, 3)
    x = AveragePooling2D(2)(x)
    x = create_basic_block(x, 128, 3)
    x = AveragePooling2D(2)(x)
    x = create_basic_block(x, 256, 3)
    x = GlobalAveragePooling2D()(x)
    x = Dense(10, activation="softmax")(x)
    model = Model(input, x)
    return model
from itertools import combinations
def loss_function_category_soft_multi(y_true, y_pred):
    # First, the ordinary per-class cross-entropy
    loss = categorical_crossentropy(y_true, y_pred)
    # Pick a start and an end class from the 10C2 combinations
    for comb in combinations(range(10), 2):
        probs_pred = K.expand_dims(K.sum(y_pred[:, comb[0]:(comb[1]+1)], axis=-1)) + K.epsilon()
        probs_true = K.expand_dims(K.sum(y_true[:, comb[0]:(comb[1]+1)], axis=-1)) + K.epsilon()
        loss += categorical_crossentropy(probs_true, probs_pred)
    return loss
def acc_category(y_true, y_pred):
    # Accuracy measured on the 10 original CIFAR-10 classes only
    return categorical_accuracy(y_true[:, :10], y_pred)
def generator(X, y, batch_size, shuffle, return_category):
    while True:
        indices = np.arange(X.shape[0])
        if shuffle:
            np.random.shuffle(indices)
        for i in range(X.shape[0]//batch_size):
            current_indices = indices[i*batch_size:((i+1)*batch_size)]
            X_batch = (X[current_indices] / 255.0).astype(np.float32)
            ncol = 12 if return_category else 10
            y_batch = to_categorical(y[current_indices], num_classes=ncol)
            if return_category:
                # Whether the class is a natural object (classes 2-7)
                y_batch[:,10] = ((y[current_indices,0] >= 2) * (y[current_indices,0] <= 7)).astype(np.float32)
                # Whether the class is an artificial object
                y_batch[:,11] = 1.0 - y_batch[:,10]
            yield X_batch, y_batch
def train(consider_category, use_tpu, rep):
    model = create_model()
    if consider_category:
        model.compile(tf.train.AdamOptimizer(), loss=loss_function_category_soft_multi, metrics=[acc_category])
    else:
        model.compile(tf.train.AdamOptimizer(), loss="categorical_crossentropy", metrics=[acc_category])
    if use_tpu:
        # Convert the Keras model to a TPU model (TensorFlow 1.x Colab API)
        tpu_grpc_url = "grpc://"+os.environ["COLAB_TPU_ADDR"]
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
        strategy = keras_support.TPUDistributionStrategy(tpu_cluster_resolver)
        model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)
    (X_train, y_train), (X_test, y_test) = cifar10.load_data()
    batch_size = 512
    print("Trial", rep, "starts")
    hist = History()
    model.fit_generator(generator(X_train, y_train, batch_size, True, False), steps_per_epoch=50000//batch_size,
                        callbacks=[hist], validation_data=generator(X_test, y_test, batch_size, True, False),
                        validation_steps=10000//batch_size, epochs=100)
    flag = "consider_category_multi" if consider_category else "normal"
    if not os.path.exists(flag):
        os.mkdir(flag)
    # Save the training history for this trial
    with open(f"{flag}/history_{flag}_{rep}.dat", "wb") as fp:
        pickle.dump(hist.history, fp)
    return flag
def train_all(consider_category, use_tpu):
    for i in range(5):
        dir_name = train(consider_category, use_tpu, i)
    # Archive the history files of all trials
    with tarfile.open(f"{dir_name}.tar.gz", "w:gz") as tar:
        tar.add(dir_name)
if __name__ == "__main__":
    K.clear_session()
    train_all(True, True)
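
Each trial pickles its hist.history dict under the output directory. A minimal sketch (not part of the gist) for comparing the saved runs afterwards could look like the following, assuming tf.keras records the acc_category metric under the key "val_acc_category":

import glob
import pickle
import numpy as np

def summarize(flag):
    # Load every pickled history for this run and report the best validation
    # accuracy per trial plus the mean over trials.
    best = []
    for path in sorted(glob.glob(f"{flag}/history_{flag}_*.dat")):
        with open(path, "rb") as fp:
            history = pickle.load(fp)
        # Key name assumed from the acc_category metric defined above
        best.append(max(history["val_acc_category"]))
    print(flag, "best val acc per trial:", best, "mean:", np.mean(best))

summarize("consider_category_multi")
summarize("normal")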