Last active
January 25, 2019 13:29
-
-
Save koshian2/ab4595f7378e53c0586005264f46d8a2 to your computer and use it in GitHub Desktop.
Train with 1000 triplet loss euclidean distance
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from tensorflow.keras import layers | |
from tensorflow.keras.models import Model | |
from tensorflow.keras.callbacks import Callback, History, LearningRateScheduler | |
import tensorflow.keras.backend as K | |
from tensorflow.contrib.tpu.python.tpu import keras_support | |
from train1000 import cifar10 | |
import numpy as np | |
from sklearn.metrics import accuracy_score, euclidean_distances | |
import os, json, tarfile | |
# VGG-like model
def create_siamese(latent_dims):
    """Build the VGG-like convolutional embedding network.

    The network takes 32x32x3 images and stacks four stages of three
    Conv(3x3, same padding) -> BatchNorm -> ReLU units with 64, 128, 256
    and 512 filters.  Every stage except the last is followed by 2x2
    average pooling, and the final feature map is global-average-pooled.

    Args:
        latent_dims: Size of the embedding.  When it is not 512 a Dense
            projection to ``latent_dims`` is appended; when it is 512 the
            pooled features are returned as they are.

    Returns:
        An uncompiled Keras ``Model`` mapping images to embeddings.
    """
    n_stages, convs_per_stage = 4, 3
    image_in = layers.Input((32, 32, 3))
    features = image_in
    for stage in range(n_stages):
        n_filters = 64 * (2 ** stage)
        for _ in range(convs_per_stage):
            features = layers.Conv2D(n_filters, 3, padding="same")(features)
            features = layers.BatchNormalization()(features)
            features = layers.Activation("relu")(features)
        # No spatial down-sampling after the final stage.
        if stage != n_stages - 1:
            features = layers.AveragePooling2D(2)(features)
    features = layers.GlobalAveragePooling2D()(features)
    if latent_dims != 512:
        features = layers.Dense(latent_dims)(features)
    return Model(image_in, features)
# Margin added to (positive distance - negative distance) inside
# euclidean_triplet_loss before the hinge is applied.
MARGIN = 0.25
# Upper-triangular matrix with the diagonal entries removed
def upper_triangle(matrix):
    """Return the strictly upper-triangular part of ``matrix``.

    Entries on and below the main diagonal are zero in the result; entries
    above it are passed through unchanged.
    """
    # Upper triangle, main diagonal still included.
    with_diagonal = tf.matrix_band_part(matrix, 0, -1)
    # Main diagonal only, turned into a 0/1 mask of its non-zero entries.
    diagonal_only = tf.matrix_band_part(matrix, 0, 0)
    nonzero_diagonal = tf.sign(tf.abs(diagonal_only))
    # Multiplying by (1 - mask) wipes every non-zero diagonal entry.
    return with_diagonal * (1.0 - nonzero_diagonal)
def euclidean_triplet_loss(label, embeddings):
    """Hinge loss comparing every positive pair with every negative pair.

    Distances are *squared* Euclidean distances between all embeddings of
    the batch.  Each same-class pair (i < j) is compared against each
    different-class pair (i < j) and ``max(pos - neg + MARGIN, 0)`` is
    summed over all combinations.

    Args:
        label: Tensor whose column 0 holds the class id of each sample; the
            remaining columns are not read.
        embeddings: Network outputs for the batch, one row per sample.

    Returns:
        Scalar tensor with the summed loss.

    NOTE(review): matrix cells that are not positive pairs are filled with
    0 and cells that are not negative pairs with the batch-wide maximum
    distance, but they still take part in the hinge, so they can add up to
    ``MARGIN`` each to the sum.
    """
    # Pairwise squared distances through broadcasting.
    as_row = tf.expand_dims(embeddings, axis=0)
    as_col = tf.expand_dims(embeddings, axis=1)
    sq_dist = tf.reduce_sum((as_row - as_col) ** 2, axis=-1)

    # 1.0 where two samples carry the same class id (label column 0).
    class_ids = label[:, 0]
    same_class = K.cast(
        tf.equal(tf.expand_dims(class_ids, axis=0),
                 tf.expand_dims(class_ids, axis=1)),
        "float32",
    )

    # Positives: tf.where is not usable here, so every combination is
    # enumerated instead.  Cells that are not positive pairs hold 0.
    is_positive = upper_triangle(same_class)
    positive_col = tf.reshape(is_positive * sq_dist, [-1, 1])

    # Negatives: cells that are not negative pairs hold the maximum distance.
    is_negative = upper_triangle(1.0 - same_class)
    largest = tf.reduce_max(sq_dist, keepdims=True)
    negative_filled = is_negative * sq_dist + (1.0 - is_negative) * largest
    negative_row = tf.reshape(negative_filled, [1, -1])

    # Broadcast column against row: all (positive, negative) combinations.
    hinge = tf.maximum(positive_col - negative_row + MARGIN, 0.0)
    return tf.reduce_sum(hinge)
class EmbeddingCallback(Callback):
    """Evaluate the learned embedding at the end of every epoch.

    After each epoch the train and test sets are embedded, pairwise
    Euclidean distances are computed, a distance threshold is chosen on the
    train set, and three classifiers are scored and printed:
    1-nearest-neighbour, simple thresholded voting and distance-weighted
    thresholded voting.  Test accuracies are kept per epoch in
    ``test_nearest_neighbor_acc``, ``test_threshold_simple`` and
    ``test_threshold_weighted``.
    """

    def __init__(self, siamese_model, X_train, y_train, X_test, y_test):
        """Store the model and the (one-hot labelled) train/test data."""
        super().__init__()
        self.model = siamese_model
        self.X_train = X_train
        self.y_train = y_train
        self.X_test = X_test
        self.y_test = y_test
        # Per-epoch test accuracy of each estimator.
        self.test_nearest_neighbor_acc = []
        self.test_threshold_simple = []
        self.test_threshold_weighted = []

    @staticmethod
    def _onehot_to_index(onehots):
        """Turn one-hot rows into class indices (sum of index * one-hot)."""
        return np.sum(np.arange(onehots.shape[1]) * onehots, axis=-1)

    @staticmethod
    def _pair_masks(distance, labels, threshold):
        """Return (upper_mask, truth_same, pred_same) boolean pair matrices.

        ``upper_mask`` selects each unordered pair once (strict upper
        triangle), ``truth_same`` marks pairs with equal labels and
        ``pred_same`` marks pairs whose distance is within ``threshold``.
        """
        # Builtin bool: the np.bool alias no longer exists in current NumPy.
        upper_mask = np.triu(np.ones(distance.shape, dtype=bool), k=1)
        truth_same = np.expand_dims(labels, axis=1) == np.expand_dims(labels, axis=0)
        pred_same = distance <= threshold
        return upper_mask, truth_same, pred_same

    # Distances used for threshold-based estimation
    def pairwise_distance_matrix(self, anchor_embedding, target_embedding):
        """Return the (n_target, n_anchor) float32 Euclidean distance matrix."""
        distance = np.zeros(
            (target_embedding.shape[0], anchor_embedding.shape[0]), dtype=np.float32)
        # Computed one target row at a time because doing it in one shot
        # runs out of memory.
        for i in range(distance.shape[0]):
            distance[i, :] = euclidean_distances(
                target_embedding[i, :].reshape(1, -1), anchor_embedding)[0]
        return distance

    # val_rate
    def true_accept(self, distance, labels, threshold):
        """Fraction of same-label pairs whose distance is within ``threshold``."""
        upper_mask, truth_same, pred_same = self._pair_masks(distance, labels, threshold)
        accepted = np.logical_and(pred_same, truth_same)
        n_same = np.sum(np.logical_and(upper_mask, truth_same))
        n_accepted = np.sum(np.logical_and(upper_mask, accepted))
        return n_accepted / n_same

    # accuracy
    def thresholded_acuracy(self, distance, labels, threshold):
        """Fraction of pairs whose same/different prediction is correct."""
        upper_mask, truth_same, pred_same = self._pair_masks(distance, labels, threshold)
        true_positive = np.logical_and(truth_same, pred_same)
        true_negative = np.logical_and(
            np.logical_not(truth_same), np.logical_not(pred_same))
        correct = np.logical_and(np.logical_or(true_positive, true_negative), upper_mask)
        return np.sum(correct) / np.sum(upper_mask)

    # f1 score
    def thresholded_f1score(self, distance, labels, threshold):
        """F1 score of the same/different pair prediction at ``threshold``."""
        upper_mask, truth_same, pred_same = self._pair_masks(distance, labels, threshold)
        truth_diff = np.logical_not(truth_same)
        pred_diff = np.logical_not(pred_same)
        tp = np.sum(np.logical_and(upper_mask, np.logical_and(truth_same, pred_same)))
        fp = np.sum(np.logical_and(upper_mask, np.logical_and(truth_diff, pred_same)))
        fn = np.sum(np.logical_and(upper_mask, np.logical_and(truth_same, pred_diff)))
        # The small epsilon keeps the ratios finite when a count is zero.
        precision = tp / (tp + fp + 1e-7)
        recall = tp / (tp + fn + 1e-7)
        return 2 * recall * precision / (precision + recall + 1e-7)

    def find_threshold(self, distance_matrix, onehots):
        """Scan thresholds in [0, 4) with step 0.001 and return the best one.

        The criterion is ``true_accept``; ``thresholded_acuracy`` or
        ``thresholded_f1score`` can be substituted in the loop below to
        select by accuracy or F1 instead.

        NOTE(review): the true-accept rate cannot decrease as the threshold
        grows, so ``argmax`` returns the smallest threshold that reaches
        the maximum rate found in the scan.

        Args:
            distance_matrix: Square (n, n) distance matrix.
            onehots: (n, n_classes) one-hot labels of the same samples.
        """
        assert distance_matrix.shape[0] == distance_matrix.shape[1]
        n = distance_matrix.shape[0]
        assert onehots.shape[0] == n
        labels = self._onehot_to_index(onehots).astype(np.int32)
        thresholds = np.arange(0.0, 4.0, 0.001, dtype=np.float32)
        rate = np.zeros(thresholds.shape, dtype=np.float32)
        for i, th in enumerate(thresholds):
            rate[i] = self.true_accept(distance_matrix, labels, th)
        print("val_rate", rate)
        best_idx = np.argmax(rate)
        print(f"Best threshold : {thresholds[best_idx]}, Best VAL : {rate[best_idx]:.04}")
        return thresholds[best_idx]

    # Estimation that returns the nearest neighbour
    def one_nearest_neighbor(self, distance_matrix, anchor_onehots, target_onehots):
        """Accuracy of labelling each target with its nearest anchor's class.

        ``distance_matrix`` is (n_target, n_anchor).  A square matrix is
        treated as train-vs-train, so the second-closest anchor is used to
        skip the sample itself.
        """
        assert distance_matrix.shape[0] == target_onehots.shape[0]
        assert distance_matrix.shape[1] == anchor_onehots.shape[0]
        indices = np.argsort(distance_matrix, axis=-1)
        # Square matrix means the train set compared with itself.
        if distance_matrix.shape[0] == distance_matrix.shape[1]:
            index = indices[:, 1]
        else:
            index = indices[:, 0]
        anchor_label = self._onehot_to_index(anchor_onehots)
        y_pred = anchor_label[index]
        y_true = self._onehot_to_index(target_onehots)
        return accuracy_score(y_true, y_pred)

    # Estimation by thresholding
    def thresholding_pred(self, distance_matrix, threshold, anchor_onehots, target_onehots, use_weighted):
        """Accuracy of voting among the anchors closer than ``threshold``.

        Each anchor within the threshold votes for its class, with weight
        ``threshold - distance`` when ``use_weighted`` is true and weight 1
        otherwise.  A target with no anchor in range gets all-zero scores
        and is therefore assigned class 0 by ``argmax``.
        """
        assert distance_matrix.shape[0] == target_onehots.shape[0]
        assert distance_matrix.shape[1] == anchor_onehots.shape[0]
        thresholded_distance = np.maximum(threshold - distance_matrix, 0.0)
        if not use_weighted:
            thresholded_distance = np.sign(thresholded_distance)
        y_pred = np.zeros(target_onehots.shape[0])
        # Average number of anchors inside the threshold per target.
        print("thresholded num : ", np.mean(np.sum(np.logical_not(np.isclose(thresholded_distance, 0.0)), axis=-1)))
        for i in range(y_pred.shape[0]):
            score = thresholded_distance[i, :].reshape(-1, 1) * anchor_onehots
            y_pred[i] = np.argmax(np.sum(score, axis=0))
        y_true = self._onehot_to_index(target_onehots)
        print(np.bincount(y_pred.astype(np.int32)))
        return accuracy_score(y_true, y_pred)

    def on_epoch_end(self, epoch, logs=None):
        """Embed both sets, pick a threshold and print all three accuracies."""
        train_embedding = self.model.predict(self.X_train)
        test_embedding = self.model.predict(self.X_test)
        # Distance matrices: rows are targets, columns are train anchors.
        distance_train = self.pairwise_distance_matrix(train_embedding, train_embedding)
        distance_test = self.pairwise_distance_matrix(train_embedding, test_embedding)
        # Threshold chosen on the train set only.
        print("")
        threshold = self.find_threshold(distance_train, self.y_train)
        # Nearest-neighbour estimation
        print("Simple 1-Nearest Neighbor")
        train_acc = self.one_nearest_neighbor(distance_train, self.y_train, self.y_train)
        test_acc = self.one_nearest_neighbor(distance_test, self.y_train, self.y_test)
        self.test_nearest_neighbor_acc.append(test_acc)
        print(f"Train acc:{train_acc:.04}, Test acc:{test_acc:.04}, Best Test:{max(self.test_nearest_neighbor_acc):.04}")
        # Threshold-based estimation
        print("Simple thresholding")
        train_acc = self.thresholding_pred(distance_train, threshold, self.y_train, self.y_train, False)
        test_acc = self.thresholding_pred(distance_test, threshold, self.y_train, self.y_test, False)
        self.test_threshold_simple.append(test_acc)
        print(f"Train acc:{train_acc:.04}, Test acc:{test_acc:.04}, Best Test:{max(self.test_threshold_simple):.04}")
        print("Weiged thresholding")
        train_acc = self.thresholding_pred(distance_train, threshold, self.y_train, self.y_train, True)
        test_acc = self.thresholding_pred(distance_test, threshold, self.y_train, self.y_test, True)
        self.test_threshold_weighted.append(test_acc)
        print(f"Train acc:{train_acc:.04}, Test acc:{test_acc:.04}, Best Test:{max(self.test_threshold_weighted):.04}")
def data_augmentation(image, crop_size=28):
    """Randomly crop-and-pad, then randomly flip, one image.

    A ``crop_size`` x ``crop_size`` window at a random position is copied
    into an all-zero float32 canvas of the input's shape, so pixels outside
    the window become zero while the kept pixels stay where they were.
    With probability 0.5 the result is then reversed along axis 1.

    Args:
        image: Array indexed as (axis 0, axis 1, channel).
        crop_size: Side length of the kept window.  Defaults to 28.

    Returns:
        float32 array with the same shape as ``image``.
    """
    outputs = np.zeros(image.shape, dtype=np.float32)
    # randint's upper bound is exclusive, so "+ 1" lets the window reach the
    # far edge; max(..., 0) keeps the bound valid for images smaller than
    # the window (the whole image is kept in that case).
    max_offset_0 = max(image.shape[0] - crop_size, 0)
    max_offset_1 = max(image.shape[1] - crop_size, 0)
    crop_x = np.random.randint(0, max_offset_0 + 1)
    crop_y = np.random.randint(0, max_offset_1 + 1)
    outputs[crop_x:crop_x + crop_size, crop_y:crop_y + crop_size, :] = \
        image[crop_x:crop_x + crop_size, crop_y:crop_y + crop_size, :]
    # flip
    if np.random.rand() >= 0.5:
        outputs = outputs[:, ::-1, :]
    return outputs
def generator(X, y, batch_size, use_augmentation, n_latent_dims):
    """Yield shuffled ``(X_batch, y_batch)`` float32 batches forever.

    Each pass draws a fresh permutation of the samples and yields only full
    batches; samples left over at the end of a pass are dropped.

    Args:
        X: Array of samples, indexed along axis 0.
        y: (n_samples, n_classes) one-hot labels.
        batch_size: Number of samples per yielded batch.
        use_augmentation: When true, every sample goes through
            ``data_augmentation`` before batching.
        n_latent_dims: Width of ``y_batch``.  Column 0 carries the class
            index and all other columns are zero-filled dummies.

    Raises:
        ValueError: On the first ``next()`` if ``batch_size`` is not in
            ``1..n_samples`` (no full batch could ever be produced, which
            would otherwise loop forever without yielding).
    """
    n_samples = X.shape[0]
    if not 0 < batch_size <= n_samples:
        raise ValueError(
            f"batch_size must be between 1 and {n_samples}, got {batch_size}")
    class_index = np.arange(y.shape[1])
    while True:
        order = np.random.permutation(n_samples)
        for start in range(0, n_samples - batch_size + 1, batch_size):
            picked = order[start:start + batch_size]
            # NOTE(review): no rescaling is done here; the original comment
            # says the pixel values are already divided by 255 - confirm
            # against the data loader.
            if use_augmentation:
                X_batch = np.asarray(
                    [data_augmentation(X[i]) for i in picked], np.float32)
            else:
                X_batch = np.asarray(X[picked], np.float32)
            # The class index goes into the first column; the rest is dummy.
            y_batch = np.zeros((batch_size, n_latent_dims), np.float32)
            y_batch[:, 0] = np.sum(class_index * y[picked], axis=1)
            yield X_batch, y_batch
def step_decay(epoch):
    """Return the learning rate for ``epoch`` (step schedule).

    1e-3 before epoch 12, 2e-4 for epochs 12-18 and 4e-5 from epoch 19 on.
    """
    if epoch >= 19:
        return 4e-5
    if epoch >= 12:
        return 2e-4
    return 1e-3
def train(n_dims, use_aug):
    """Train the embedding network with the triplet loss on a Colab TPU.

    Loads the data with ``cifar10()``, builds and compiles the network,
    converts it to a TPU model and fits it for 25 epochs with the
    per-epoch embedding evaluation, a ``History`` recorder and the
    ``step_decay`` learning-rate schedule attached.

    Args:
        n_dims: Embedding size; also the width of the dummy label rows
            produced by ``generator``.
        use_aug: Whether ``generator`` applies data augmentation.

    The TPU address is read from the ``COLAB_TPU_ADDR`` environment
    variable; a missing variable raises ``KeyError``.
    """
    (X_train, y_train), (X_test, y_test) = cifar10()

    # The model is compiled as a plain Keras model first and only then
    # converted to its TPU counterpart.
    model = create_siamese(n_dims)
    model.compile("adam", euclidean_triplet_loss)

    resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        "grpc://" + os.environ["COLAB_TPU_ADDR"])
    tpu_strategy = keras_support.TPUDistributionStrategy(resolver)
    model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=tpu_strategy)

    callbacks = [
        EmbeddingCallback(model, X_train, y_train, X_test, y_test),
        History(),
        LearningRateScheduler(step_decay),
    ]

    batch_size = 200
    batches = generator(X_train, y_train, batch_size, use_aug, n_dims)
    # One Keras "epoch" here is 400 passes' worth of batches over the
    # training set.
    model.fit_generator(
        batches,
        steps_per_epoch=X_train.shape[0] * 400 // batch_size,
        callbacks=callbacks,
        max_queue_size=1,
        epochs=25,
    )
if __name__ == "__main__":
    # Drop any existing Keras session state before the model is built.
    K.clear_session()
    print(512, "starts")
    # 512-dimensional embedding, no data augmentation.
    train(n_dims=512, use_aug=False)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment