Skip to content

Instantly share code, notes, and snippets.

@nwatab
Last active August 2, 2019 05:34
Show Gist options
  • Save nwatab/531d57166a6f627062e929132b231660 to your computer and use it in GitHub Desktop.
Save nwatab/531d57166a6f627062e929132b231660 to your computer and use it in GitHub Desktop.
Grad-CAM++ by fine-tuning VGG16 for anomaly detection
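  1. Fine-tune VGG16 on the DAGM2007 dataset. 109 epochs were enough; training took 8 hours on a Tesla K80 (12 GB GPU memory, 61 GB RAM, 100 GB SSD). The training script runs 120 epochs and keeps the checkpoint with the lowest validation loss in `model.h5`.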

https://resources.mpi-inf.mpg.de/conference/dagm/2007/prizes.html

  2. Generate a heatmap with `python heatmap.py image.jpg`. The result is saved next to the input as `image.jpg_GCAM++_VGG16.jpg`.

image

OK Class3
class3_90 | class3_90.png_GCAM++_VGG16

OK Class5
class5_121 | class5_121.png_GCAM++_VGG16

NG Class3
class3_121 | class3_121.png_GCAM++_VGG16

NG Class5
class5_121 | class5_121.png_GCAM++_VGG16

Thanks to
最新手法!「Grad-CAM++」のレビューと実装 (The latest method: a review and implementation of Grad-CAM++)
https://qiita.com/Dason08/items/a8013b3fa4d303f5c41c
kerasでvgg16とGrad-CAMの実装による異常検出および異常箇所の可視化 (Anomaly detection and visualization of anomalous regions with VGG16 and Grad-CAM in Keras)
https://qiita.com/T_Tao/items/0e869e440067518b6b58

import sys
import cv2
import keras.backend as K
from keras.models import load_model
from keras.preprocessing.image import load_img, img_to_array
import numpy as np
# Learning phase 0 = test phase, so layers such as BatchNormalization run in
# inference mode while the heatmap is computed.
K.set_learning_phase(0)


def _grad_cam_plus_plus_map(conv_output, conv_first_grad, conv_second_grad, conv_third_grad):
    """Combine one sample's activations and derivatives into a Grad-CAM++ map.

    Args:
        conv_output: activations of the target conv layer, shape (H, W, C).
        conv_first_grad: first-order derivative term, shape (H, W, C).
        conv_second_grad: second-order derivative term, shape (H, W, C).
        conv_third_grad: third-order derivative term, shape (H, W, C).

    Returns:
        An (H, W) array scaled so its maximum is 1. It is all zeros when no
        location has positive evidence.
    """
    n_channels = conv_first_grad.shape[2]

    # Per-channel sum of the activations over all spatial positions.
    global_sum = np.sum(conv_output.reshape((-1, n_channels)), axis=0)

    # alpha = d2 / (2 * d2 + sum(A) * d3). Zero denominators are replaced by 1.
    alpha_num = conv_second_grad
    alpha_denom = conv_second_grad * 2.0 + conv_third_grad * global_sum.reshape((1, 1, n_channels))
    alpha_denom = np.where(alpha_denom != 0.0, alpha_denom, np.ones(alpha_denom.shape))
    alphas = alpha_num / alpha_denom

    # Normalise alphas so they sum to 1 over the spatial positions of each
    # channel. Channels whose alphas sum to 0 are left unchanged.
    alpha_norm = np.sum(np.sum(alphas, axis=0), axis=0)
    alpha_norm = np.where(alpha_norm != 0.0, alpha_norm, np.ones(alpha_norm.shape))
    alphas /= alpha_norm.reshape((1, 1, n_channels))

    # Channel weight w_k = sum over (i, j) of alpha_ij^k * relu(first_grad_ij^k).
    # Summing over spatial positions only (axis=0) gives each channel its own
    # weight. Summing over everything would give every channel the same scalar.
    weights = np.maximum(conv_first_grad, 0.0)
    channel_weights = np.sum((weights * alphas).reshape((-1, n_channels)), axis=0)

    # Weighted sum of the activation maps, then ReLU.
    cam = np.sum(channel_weights * conv_output, axis=2)
    cam = np.maximum(cam, 0)

    # Scale to [0, 1], guarding against 0/0 (NaN) when the map is all zeros.
    peak = np.max(cam)
    if peak > 0:
        cam = cam / peak
    return cam


def Grad_Cam_plus_plus(input_model, layer_name, img_array):
    """Render a Grad-CAM++ heatmap for the model's predicted class over an image.

    Args:
        input_model: Keras model; the output of its last layer is taken as the
            class score(s).
        layer_name: name of the convolutional layer to explain
            (e.g. 'block5_conv3').
        img_array: image array of shape (rows, cols, 3). Assumed to be RGB with
            values in [0, 255], as produced by img_to_array(load_img(...)).

    Returns:
        A float array of shape (rows, cols, 3) in BGR order: the JET-coloured
        heatmap plus half of the input image. Values may exceed 255.
    """
    model = input_model
    (row, col, _) = img_array.shape

    # Preprocessing: add a batch axis and scale to [0, 1]. This matches the
    # rescale=1./255 used by the training script's generators.
    X = np.expand_dims(img_array, axis=0).astype('float32')
    preprocessed_input = X / 255.0

    # Find the predicted class.
    # NOTE(review): with a single-unit sigmoid head, argmax is always 0, so the
    # map always explains the positive-label score. Confirm this is intended.
    predictions = model.predict(preprocessed_input)
    class_idx = np.argmax(predictions[0])

    # Differentiate only the chosen class's score, so gradients are not summed
    # over every output unit. This is equivalent for a one-unit head.
    class_score = model.layers[-1].output[:, class_idx]
    conv_output = model.get_layer(layer_name).output
    grads = K.gradients(class_score, conv_output)[0]

    # Grad-CAM++ closed form with Y = exp(S): d^nY/dA^n = exp(S) * (dS/dA)^n.
    # NOTE(review): S here is the last layer's output (a sigmoid Activation in
    # the training script), not a pre-activation logit. Confirm this is intended.
    exp_score = K.exp(class_score)[0]
    first_derivative = exp_score * grads
    second_derivative = first_derivative * grads
    third_derivative = second_derivative * grads

    # Build a function that maps model.input to the activations and the three
    # derivative terms.
    gradient_function = K.function(
        [model.input],
        [conv_output, first_derivative, second_derivative, third_derivative])
    outputs = gradient_function([preprocessed_input])
    # Drop the batch axis from each result.
    conv_value, conv_first_grad, conv_second_grad, conv_third_grad = [o[0] for o in outputs]

    grad_CAM_map = _grad_cam_plus_plus_map(
        conv_value, conv_first_grad, conv_second_grad, conv_third_grad)

    # Draw the heatmap. cv2.resize takes dsize as (width, height), and the
    # interpolation must be passed by keyword: the third positional slot is dst.
    grad_CAM_map = cv2.resize(grad_CAM_map, (col, row), interpolation=cv2.INTER_LINEAR)
    # Pseudo-colour the single-channel map.
    jetcam = cv2.applyColorMap(np.uint8(255 * grad_CAM_map), cv2.COLORMAP_JET)
    # applyColorMap returns BGR, the order cv2.imwrite expects, while the input
    # image is RGB. Reverse its channels before blending in half of it.
    jetcam = np.float32(jetcam) + img_array[..., ::-1] / 2
    return jetcam
# Script entry: explain one image with the fine-tuned model saved as 'model.h5'
# (loaded relative to the current working directory).
row = 512
col = 512
img_shape = (row, col, 3)
model = load_model('model.h5')
target_layer = 'block5_conv3'

# Image to explain: the first CLI argument, or a default validation sample when
# no (or an empty) argument is given. Indexing sys.argv[1] unconditionally
# would raise IndexError before the default could apply.
if len(sys.argv) > 1 and sys.argv[1]:
    image_path = sys.argv[1]
else:
    image_path = 'dagm2007andothers/validation/NG/class1_1.png'

img = img_to_array(load_img(image_path, target_size=(row, col)))
img_GCAMplusplus = Grad_Cam_plus_plus(model, target_layer, img)

# Output is written next to the input, e.g. 'x.png' -> 'x.png_GCAM++_VGG16.jpg'.
img_Gplusplusname = image_path + "_GCAM++_VGG16.jpg"
# cv2.imwrite reports failure (bad path, unsupported data) by returning False
# rather than raising, so check it explicitly.
if not cv2.imwrite(img_Gplusplusname, img_GCAMplusplus):
    sys.exit('failed to write ' + img_Gplusplusname)
# -=- encoding: UTF-8 -=-
from keras.callbacks import BaseLogger, ModelCheckpoint, TensorBoard
from keras.layers import Activation, Input, Flatten, Dense, BatchNormalization
from keras.applications import VGG16
from keras.models import Model
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
# Expected data layout (labels come from the OK/NG sub-directory names):
#   './DAGM2007/train/OK/class[1-6]_[1-120].png'
#   './DAGM2007/train/NG/class[1-6]_[1-120].png'
#   './DAGM2007/validation/OK/class[1-6]_[121-150].png'
#   './DAGM2007/validation/NG/class[1-6]_[121-150].png'
# flow_from_directory assigns label indices in alphanumeric order of the
# sub-directory names, so NG -> 0 and OK -> 1. The sigmoid output therefore
# estimates P(OK).
datapath = './DAGM2007/'
batch_size = 16
image_size = (512, 512)

# Only the training data is augmented. Both generators scale pixels to [0, 1].
train_datagen = ImageDataGenerator(
    rescale=1./255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)
test_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
    datapath + 'train',
    target_size=image_size,
    batch_size=batch_size,
    class_mode='binary')
validation_generator = test_datagen.flow_from_directory(
    datapath + 'validation',
    target_size=image_size,
    batch_size=batch_size,
    class_mode='binary')
# Model: an ImageNet-pretrained VGG16 convolutional base (no top) followed by a
# small binary-classification head. No layers are frozen here.
img_shape = (512, 512, 3)
vgg = VGG16(weights='imagenet', include_top=False, input_shape=img_shape)

# Head layers, applied in order on top of the VGG16 feature map.
head = [
    Flatten(),
    Dense(256),
    Activation('relu'),
    BatchNormalization(),
    Dense(1),
]
h = vgg.output
for layer in head:
    h = layer(h)

# Single sigmoid unit for the binary OK/NG label.
pred = Activation('sigmoid')(h)
model = Model(vgg.input, pred)
model.summary()
# Training schedule. Sample counts are read from the generators so the step
# counts follow what is actually on disk (1440 / 360 for the expected DAGM2007
# split).
epochs = 120
train_samples = train_generator.samples
validation_samples = validation_generator.samples

model.compile(loss='binary_crossentropy',
              optimizer=Adam(lr=1e-4),
              metrics=['accuracy'])

# Keras also attaches a BaseLogger automatically; listing it here is redundant.
base_cb = BaseLogger(stateful_metrics=None)
# Keep only the checkpoint with the lowest validation loss. The heatmap script
# loads this 'model.h5'.
model_cb = ModelCheckpoint('model.h5', monitor='val_loss', verbose=0, save_best_only=True, save_weights_only=False, mode='auto', period=1)
tb_cb = TensorBoard(log_dir='./logs', histogram_freq=0, write_graph=True, write_images=True)

history = model.fit_generator(
    train_generator,
    steps_per_epoch=train_samples // batch_size,
    epochs=epochs,
    validation_data=validation_generator,
    # Round up so each validation pass sees every image exactly once.
    # Flooring (360 // 16 = 22 steps) skipped 8 images per epoch.
    validation_steps=(validation_samples + batch_size - 1) // batch_size,
    shuffle=True,
    callbacks=[base_cb, model_cb, tb_cb]
)

# Weights from the final epoch. The best-by-val_loss model is in 'model.h5'.
model.save_weights('first_try.h5')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment