Created
August 29, 2018 01:57
-
-
Save koshian2/cb31c3dce9ba05dabbe72136afb0daeb to your computer and use it in GitHub Desktop.
AnimeFace Character Dataset in deep learning
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.applications.densenet import DenseNet121 | |
from keras.layers import GlobalAvgPool2D, Dense | |
from keras.models import Model | |
from keras.preprocessing.image import ImageDataGenerator | |
from keras.optimizers import SGD | |
from keras.regularizers import l2 | |
from keras.callbacks import ModelCheckpoint | |
import os | |
import glob | |
import pickle | |
import numpy as np | |
# Number of classes: one per class sub-directory of the training set.
# Only directories are counted, so a stray file in the training folder
# (e.g. ".DS_Store") cannot inflate the width of the softmax layer that is
# sized from this value.  The original author's note gives 176 here.
_train_dir = "animeface-character-dataset/train"
nb_classes = sum(
    os.path.isdir(os.path.join(_train_dir, entry))
    for entry in os.listdir(_train_dir)
)
# Number of training samples: PNG files exactly one level below the class
# directories.
nb_train_images = len(glob.glob("animeface-character-dataset/train/*/*.png"))
# Data generator (with data augmentation) that can also apply mix-up.
class MixUpDataGenerator(ImageDataGenerator):
    """ImageDataGenerator that blends pairs of directory batches (mix-up).

    Args:
        mix_up_alpha: Parameter of the symmetric Beta(alpha, alpha)
            distribution from which the per-sample blending coefficients are
            drawn.  A value <= 0 disables mix-up in `flow_from_directory`.
        *args, **kwargs: Forwarded unchanged to `ImageDataGenerator`.
    """

    def __init__(self, mix_up_alpha, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mix_up_alpha = mix_up_alpha

    def mix_up(self, X1, y1, X2, y2):
        """Blend two equally sized batches sample by sample.

        One coefficient per sample is drawn from Beta(alpha, alpha) and the
        same coefficient is applied to the sample's input and its label.

        Args:
            X1, X2: Input batches with the batch axis first.
            y1, y2: Label batches with the batch axis first.

        Returns:
            Tuple `(X, y)` of the blended inputs and blended labels.

        Raises:
            ValueError: If the four arrays do not share the same batch size.
        """
        batch_size = X1.shape[0]
        if not (y1.shape[0] == X2.shape[0] == y2.shape[0] == batch_size):
            raise ValueError(
                "mix_up requires X1, y1, X2 and y2 to share the same batch size"
            )
        lam = np.random.beta(self.mix_up_alpha, self.mix_up_alpha, batch_size)
        # Append as many singleton axes as each array has non-batch axes, so
        # the coefficients broadcast per sample whatever the array rank is
        # (yields (B, 1, 1, 1) for 4-D inputs and (B, 1) for 2-D labels).
        X_lam = lam.reshape((batch_size,) + (1,) * (len(X1.shape) - 1))
        y_lam = lam.reshape((batch_size,) + (1,) * (len(y1.shape) - 1))
        X = X1 * X_lam + X2 * (1 - X_lam)
        y = y1 * y_lam + y2 * (1 - y_lam)
        return X, y

    def flow_from_directory(self, *args, **kwargs):
        """Yield `(batch_x, batch_y)` tuples forever, mixed-up when enabled.

        All arguments are forwarded to the parent `flow_from_directory`.
        Because this method contains `yield`, calling it returns a plain
        generator that wraps the parent's iterator.

        When `mix_up_alpha > 0` every yielded batch is a blend of two batches
        drawn from the same underlying iterator, so each yielded batch
        consumes at least two parent batches.
        """
        batches = super().flow_from_directory(*args, **kwargs)
        while True:
            batch_x, batch_y = next(batches)
            # Mix-up
            if self.mix_up_alpha > 0:
                while True:
                    batch_x_2, batch_y_2 = next(batches)
                    m1, m2 = batch_x.shape[0], batch_x_2.shape[0]
                    # A partner batch that is too small is discarded and
                    # another one is drawn; a larger one is cut down to size.
                    if m1 <= m2:
                        batch_x_2 = batch_x_2[:m1]
                        batch_y_2 = batch_y_2[:m1]
                        break
                batch_x, batch_y = self.mix_up(
                    batch_x, batch_y, batch_x_2, batch_y_2)
            yield (batch_x, batch_y)
# Generators
# Training-time augmentation: pixel values are multiplied by 1/255, images
# are shifted by up to 15/160 of their width/height, rotated by up to 20
# degrees and randomly flipped horizontally.
_augmentation = dict(
    rescale=1/255.0,
    width_shift_range=15.0/160,
    height_shift_range=15.0/160,
    rotation_range=20,
    horizontal_flip=True,
)
traingen = MixUpDataGenerator(mix_up_alpha=0.2, **_augmentation)
# Test-time generator: rescaling only, no augmentation and no mix-up.
testgen = ImageDataGenerator(rescale=1/255.0)
# Transfer learning on DenseNet121, built without its classification head.
dense = DenseNet121(include_top=False)

# Do not train anything up to conv3: freeze every layer that comes before the
# first layer whose name contains "conv4".
first_trainable = next(
    (idx for idx, layer in enumerate(dense.layers) if "conv4" in layer.name),
    len(dense.layers),
)
for layer in dense.layers[:first_trainable]:
    layer.trainable = False

# Add weight decay: attach an L2 regularizer to every layer whose name
# contains "_conv" but not "pool".
# NOTE(review): assigning kernel_regularizer on an already-built layer may not
# register the penalty in the model's loss -- confirm against the Keras
# version in use.
weight_decay = 0.01
conv_layers = [
    layer for layer in dense.layers
    if "_conv" in layer.name and "pool" not in layer.name
]
for layer in conv_layers:
    layer.kernel_regularizer = l2(weight_decay)

# Outputs: global average pooling followed by a softmax over the classes.
pooled = GlobalAvgPool2D()(dense.output)
class_probs = Dense(nb_classes, activation="softmax")(pooled)
model = Model(dense.input, class_probs)
# Compile: SGD with momentum, categorical cross-entropy, accuracy metric.
model.compile(
    optimizer=SGD(lr=0.008, momentum=0.9),
    loss="categorical_crossentropy",
    metrics=["acc"],
)

# Checkpoints: the file name embeds the epoch number and validation loss.
cp = ModelCheckpoint("model_{epoch:02d}-{val_loss:.2f}.h5")

# Train
train_batch_size = 128
train_batches = traingen.flow_from_directory(
    "animeface-character-dataset/train",
    target_size=(160, 160),
    class_mode="categorical",
    batch_size=train_batch_size,
)
val_batches = testgen.flow_from_directory(
    "animeface-character-dataset/test",
    target_size=(160, 160),
    class_mode="categorical",
)
history = model.fit_generator(
    train_batches,
    # steps_per_epoch must be a whole number of batches; round up so the
    # final partial batch is still counted.
    steps_per_epoch=int(np.ceil(nb_train_images / train_batch_size)),
    epochs=25,
    validation_data=val_batches,
    callbacks=[cp],
).history

# Persist the per-epoch metrics dictionary for later inspection.
with open("history.dat", "wb") as fp:
    pickle.dump(history, fp)
# Aggregate accuracy per class on the test set.
from keras.models import load_model

# shuffle=False keeps the batches in the iterator's file order, so the
# predictions line up with `test_iter.classes`.  Predicting batch by batch
# avoids hard-coding the test-set size and loading every image at once.
test_iter = testgen.flow_from_directory(
    "animeface-character-dataset/test",
    target_size=(160, 160),
    class_mode="sparse",
    batch_size=128,
    shuffle=False,
)
y_test = test_iter.classes

# Checkpoint file written during training; the name is fixed to one
# particular run (epoch 12, val_loss 1.03).
best_model = load_model("model_12-1.03.h5")
pred = best_model.predict_generator(test_iter, steps=len(test_iter))
pred_class = np.argmax(pred, axis=1)

# Fraction of each class's test samples predicted as that class.  A class
# with no test samples stays NaN instead of triggering a mean-of-empty warning.
acc_by_class = np.full(nb_classes, np.nan)
for i in range(nb_classes):
    in_class = y_test == i
    if in_class.any():
        acc_by_class[i] = np.mean(pred_class[in_class] == i)

# One accuracy value per line, in class-index order.
with open("pred_acc.txt", "w") as fp:
    fp.write("\n".join([str(x) for x in acc_by_class]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment