# Train with 1000: triplet loss with Euclidean distance
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import Callback, History, LearningRateScheduler
import tensorflow.keras.backend as K
from tensorflow.contrib.tpu.python.tpu import keras_support
from train1000 import cifar10
import numpy as np
from sklearn.metrics import accuracy_score, euclidean_distances
import os, json, tarfile
# VGG-like model
def create_siamese(latent_dims):
    """Build a VGG-like convolutional embedding network for 32x32x3 images.

    The network has four stages of three (Conv2D 3x3 -> BatchNorm -> ReLU) blocks
    with 64, 128, 256 and 512 filters. Every stage except the last ends with 2x2
    average pooling. The final feature map is global-average-pooled to 512 features.
    A Dense layer projects to ``latent_dims`` unless that is already 512.

    Args:
        latent_dims: size of the output embedding.

    Returns:
        A Keras ``Model`` mapping an image batch to its embeddings.
    """
    inputs = layers.Input((32, 32, 3))
    features = inputs
    n_stages = 4
    for stage in range(n_stages):
        filters = 64 * 2 ** stage
        for _ in range(3):
            features = layers.Conv2D(filters, 3, padding="same")(features)
            features = layers.BatchNormalization()(features)
            features = layers.Activation("relu")(features)
        # Downsample between stages, but not after the last one.
        if stage < n_stages - 1:
            features = layers.AveragePooling2D(2)(features)
    features = layers.GlobalAveragePooling2D()(features)
    # GAP already yields 512 features (last stage width); only project otherwise.
    if latent_dims != 512:
        features = layers.Dense(latent_dims)(features)
    return Model(inputs, features)
# Margin of the triplet hinge in euclidean_triplet_loss. That loss compares
# squared Euclidean distances, so this value is in squared-distance units.
MARGIN = 0.25
# Strictly upper-triangular part of a matrix (upper triangle with the diagonal removed).
def upper_triangle(matrix):
    """Return ``matrix`` with its diagonal and everything below it set to 0.

    Used to keep each unordered pair (i < j) of a pairwise matrix exactly once.
    """
    # Upper triangle including the diagonal ...
    with_diagonal = tf.matrix_band_part(matrix, 0, -1)
    # ... minus the diagonal alone leaves only the entries strictly above it.
    diagonal_only = tf.matrix_band_part(matrix, 0, 0)
    return with_diagonal - diagonal_only
def euclidean_triplet_loss(label, embeddings):
    """Batch-wide triplet loss on squared Euclidean distances.

    Every same-label pair in the batch is compared with every different-label pair:
    loss = sum over (positive pair p, negative pair n) of max(d_p - d_n + MARGIN, 0).
    Only the strict upper triangle of the pairwise matrices is used, so each
    unordered pair is counted once and self-pairs are ignored. The intermediate
    hinge matrix has shape (batch**2, batch**2), so memory grows with batch**4.

    Args:
        label: target tensor whose column 0 holds the integer class id of each
            sample; the other columns are ignored.
        embeddings: model output, presumably shaped (batch, latent_dims) — TODO confirm.

    Returns:
        Scalar loss tensor (a sum, not a mean).
    """
    # Pairwise squared Euclidean distances, shape (batch, batch). No sqrt is taken.
    x1 = tf.expand_dims(embeddings, axis=0)
    x2 = tf.expand_dims(embeddings, axis=1)
    sq_dist = tf.reduce_sum((x1 - x2) ** 2, axis=-1)

    # 1.0 where the two samples share a class id.
    lb1 = tf.expand_dims(label[:, 0], axis=0)
    lb2 = tf.expand_dims(label[:, 0], axis=1)
    equal_mat = K.cast(tf.equal(lb1, lb2), "float32")

    # tf.where is avoided, so pairs are selected with dense 0/1 masks over all pairs.
    positive_flag = upper_triangle(equal_mat)
    negative_flag = upper_triangle(1.0 - equal_mat)

    # Positives as a column and negatives as a row: broadcasting forms every
    # (positive pair, negative pair) combination.
    positive_dist = tf.reshape(positive_flag * sq_dist, [-1, 1])
    negative_dist = tf.reshape(negative_flag * sq_dist, [1, -1])
    pair_mask = tf.reshape(positive_flag, [-1, 1]) * tf.reshape(negative_flag, [1, -1])

    # Masking after the hinge is required: a placeholder distance in a non-pair slot
    # (0 or any fill value) would still produce max(... + MARGIN, 0) > 0 terms.
    loss = tf.maximum(positive_dist - negative_dist + MARGIN, 0.0) * pair_mask
    return tf.reduce_sum(loss)
class EmbeddingCallback(Callback):
    """Evaluate the learned embedding with distance-based classifiers after each epoch.

    Train samples act as anchors. At every epoch end the callback:
      * embeds the train and test sets;
      * builds Euclidean distance matrices;
      * picks a distance threshold on the train-vs-train matrix;
      * prints train/test accuracy, and the best test accuracy so far, for three
        classifiers: 1-nearest neighbour, thresholded vote and distance-weighted
        thresholded vote.
    """

    def __init__(self, siamese_model, X_train, y_train, X_test, y_test):
        """Store the model and the evaluation data.

        Args:
            siamese_model: model whose ``predict`` returns embeddings. Keras may later
                replace ``self.model`` via ``set_model`` with the model being fitted.
            X_train, X_test: inputs for ``predict``.
            y_train, y_test: one-hot label matrices, shape (n_samples, n_classes).
        """
        # Let the Keras base class initialise its own attributes.
        super().__init__()
        self.model = siamese_model
        self.X_train = X_train
        self.y_train = y_train
        self.X_test = X_test
        self.y_test = y_test
        # Per-epoch test accuracies, used to report the best value seen so far.
        self.test_nearest_neighbor_acc = []
        self.test_threshold_simple = []
        self.test_threshold_weighted = []

    @staticmethod
    def _pair_masks(distance, labels):
        """Return ``(upper_mask, truth_same)`` for a square pairwise matrix.

        ``upper_mask`` is True strictly above the diagonal, so each unordered pair is
        counted once and self-pairs are excluded. ``truth_same[i, j]`` is True when
        ``labels[i] == labels[j]``.
        """
        upper_mask = np.triu(np.ones(distance.shape, dtype=bool), k=1)
        truth_same = np.expand_dims(labels, axis=1) == np.expand_dims(labels, axis=0)
        return upper_mask, truth_same

    def pairwise_distance_matrix(self, anchor_embedding, target_embedding, chunk_size=256):
        """Return Euclidean distances, shape (n_target, n_anchor), as float32.

        Targets are processed ``chunk_size`` rows at a time. This bounds peak memory
        without issuing one sklearn call per row.

        Raises:
            ValueError: if ``chunk_size`` is smaller than 1.
        """
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        n_target = target_embedding.shape[0]
        distance = np.zeros((n_target, anchor_embedding.shape[0]), dtype=np.float32)
        for start in range(0, n_target, chunk_size):
            stop = start + chunk_size
            distance[start:stop] = euclidean_distances(target_embedding[start:stop], anchor_embedding)
        return distance

    def true_accept(self, distance, labels, threshold):
        """Return the validation rate (VAL) at ``threshold``.

        VAL is the fraction of same-label pairs whose distance is <= ``threshold``.
        It is NaN (with a NumPy warning) when there are no same-label pairs.
        """
        upper_mask, truth_same = self._pair_masks(distance, labels)
        same_pairs = np.logical_and(upper_mask, truth_same)
        accepted = np.logical_and(same_pairs, distance <= threshold)
        return np.sum(accepted) / np.sum(same_pairs)

    def thresholded_acuracy(self, distance, labels, threshold):
        """Return pair accuracy when predicting "same label" iff distance <= ``threshold``."""
        upper_mask, truth_same = self._pair_masks(distance, labels)
        pred_same = distance <= threshold
        # A pair is correct when the same/different prediction matches the truth.
        correct = np.logical_and(pred_same == truth_same, upper_mask)
        return np.sum(correct) / np.sum(upper_mask)

    def thresholded_f1score(self, distance, labels, threshold):
        """Return the F1 score of "same label" pair prediction (distance <= ``threshold``)."""
        upper_mask, truth_same = self._pair_masks(distance, labels)
        pred_same = distance <= threshold
        tp = np.sum(upper_mask & truth_same & pred_same)
        fp = np.sum(upper_mask & ~truth_same & pred_same)
        fn = np.sum(upper_mask & truth_same & ~pred_same)
        # The 1e-7 terms avoid division by zero when there are no predictions/positives.
        precision = tp / (tp + fp + 1e-7)
        recall = tp / (tp + fn + 1e-7)
        return 2 * recall * precision / (precision + recall + 1e-7)

    def find_threshold(self, distance_matrix, onehots, metric="val"):
        """Grid-search a distance threshold on a square (train-vs-train) matrix.

        Thresholds 0.000, 0.001, ..., 3.999 are scored with the chosen pairwise
        metric. The highest-scoring threshold is returned; on ties, ``np.argmax``
        picks the first (smallest) one.

        Args:
            distance_matrix: square pairwise distance matrix.
            onehots: one-hot labels for its rows/columns.
            metric: ``"val"`` (validation rate, default), ``"accuracy"`` or ``"f1"``.

        Returns:
            The selected threshold (float32 scalar).

        Raises:
            ValueError: if ``metric`` is not one of the names above.

        NOTE(review): VAL never decreases as the threshold grows. With
        ``metric="val"`` this therefore returns the smallest threshold reaching the
        maximum VAL, regardless of how many different-label pairs it also accepts.
        """
        assert distance_matrix.shape[0] == distance_matrix.shape[1]
        n = distance_matrix.shape[0]
        assert onehots.shape[0] == n
        # metric name -> (scoring method, label for the rate print, label for the best print)
        scorers = {
            "val": (self.true_accept, "val_rate", "VAL"),
            "accuracy": (self.thresholded_acuracy, "accuracy", "accuracy"),
            "f1": (self.thresholded_f1score, "f1 score", "F1score"),
        }
        if metric not in scorers:
            raise ValueError(f"unknown metric {metric!r}; expected one of {sorted(scorers)}")
        score_fn, rate_name, best_name = scorers[metric]

        # One-hot rows -> integer class ids.
        labels = np.sum(np.arange(onehots.shape[1]) * onehots, axis=1).astype(np.int32)
        thresholds = np.arange(0.0, 4.0, 0.001, dtype=np.float32)
        rate = np.array([score_fn(distance_matrix, labels, th) for th in thresholds], dtype=np.float32)
        print(rate_name, rate)
        best_idx = np.argmax(rate)
        print(f"Best threshold : {thresholds[best_idx]}, Best {best_name} : {rate[best_idx]:.04}")
        return thresholds[best_idx]

    def one_nearest_neighbor(self, distance_matrix, anchor_onehots, target_onehots):
        """Return 1-nearest-neighbour accuracy of targets against the anchors.

        A square matrix is treated as train-vs-train. There, the nearest column is
        normally the sample itself, so the second-nearest anchor is used instead
        (leave-one-out). NOTE(review): a non-train target set with exactly as many
        samples as the anchors would also be treated this way.
        """
        assert distance_matrix.shape[0] == target_onehots.shape[0]
        assert distance_matrix.shape[1] == anchor_onehots.shape[0]
        indices = np.argsort(distance_matrix, axis=-1)
        is_square = distance_matrix.shape[0] == distance_matrix.shape[1]
        column = 1 if is_square and distance_matrix.shape[1] > 1 else 0
        index = indices[:, column]
        anchor_label = np.sum(np.arange(anchor_onehots.shape[1]).reshape(1, -1) * anchor_onehots, axis=-1)
        y_pred = anchor_label[index]
        y_true = np.sum(np.arange(target_onehots.shape[1]) * target_onehots, axis=-1)
        return accuracy_score(y_true, y_pred)

    def thresholding_pred(self, distance_matrix, threshold, anchor_onehots, target_onehots, use_weighted):
        """Return accuracy of a vote among anchors closer than ``threshold``.

        Each anchor within the threshold votes for its class with weight 1. With
        ``use_weighted``, the weight is ``threshold - distance`` instead. Targets with
        no anchor inside the threshold score 0 for every class and fall back to
        class 0 (``np.argmax``). For the train-vs-train matrix the self-distance is
        (approximately) 0, so each train sample also votes for its own label.
        """
        assert distance_matrix.shape[0] == target_onehots.shape[0]
        assert distance_matrix.shape[1] == anchor_onehots.shape[0]
        weights = np.maximum(threshold - distance_matrix, 0.0)
        if not use_weighted:
            weights = np.sign(weights)
        print("thresholded num : ", np.mean(np.sum(np.logical_not(np.isclose(weights, 0.0)), axis=-1)))
        # scores[i, c] = total weight of class-c anchors within the threshold of target i.
        scores = weights @ anchor_onehots
        y_pred = np.argmax(scores, axis=1)
        y_true = np.sum(np.arange(target_onehots.shape[1]) * target_onehots, axis=-1)
        return accuracy_score(y_true, y_pred)

    @staticmethod
    def _report(history, train_acc, test_acc):
        """Record ``test_acc`` in ``history`` and print train/test/best-test accuracy."""
        history.append(test_acc)
        print(f"Train acc:{train_acc:.04}, Test acc:{test_acc:.04}, Best Test:{max(history):.04}")

    def on_epoch_end(self, epoch, logs=None):
        """Embed train/test data and print the three distance-based accuracies."""
        train_embedding = self.model.predict(self.X_train)
        test_embedding = self.model.predict(self.X_test)
        # Train samples are the anchors for both matrices.
        distance_train = self.pairwise_distance_matrix(train_embedding, train_embedding)
        distance_test = self.pairwise_distance_matrix(train_embedding, test_embedding)
        threshold = self.find_threshold(distance_train, self.y_train)

        print("Simple 1-Nearest Neighbor")
        train_acc = self.one_nearest_neighbor(distance_train, self.y_train, self.y_train)
        test_acc = self.one_nearest_neighbor(distance_test, self.y_train, self.y_test)
        self._report(self.test_nearest_neighbor_acc, train_acc, test_acc)

        print("Simple thresholding")
        train_acc = self.thresholding_pred(distance_train, threshold, self.y_train, self.y_train, False)
        test_acc = self.thresholding_pred(distance_test, threshold, self.y_train, self.y_test, False)
        self._report(self.test_threshold_simple, train_acc, test_acc)

        print("Weiged thresholding")
        train_acc = self.thresholding_pred(distance_train, threshold, self.y_train, self.y_train, True)
        test_acc = self.thresholding_pred(distance_test, threshold, self.y_train, self.y_test, True)
        self._report(self.test_threshold_weighted, train_acc, test_acc)
def data_augmentation(image):
    """Randomly blank the image border and optionally mirror it left-right.

    A 28x28 window is kept in place; its top-left corner is drawn uniformly from
    0..3 on each of the first two axes. Everything outside the window becomes 0,
    and the content is not translated. The result is then flipped along axis 1
    with probability 0.5.

    Args:
        image: array indexed as (rows, cols, channels); presumably a 32x32x3
            CIFAR-10 image — TODO confirm.

    Returns:
        float32 array with the same shape as ``image``.
    """
    row0 = np.random.randint(0, 4)
    col0 = np.random.randint(0, 4)
    window = (slice(row0, row0 + 28), slice(col0, col0 + 28), slice(None))
    masked = np.zeros(image.shape, dtype=np.float32)
    masked[window] = image[window]
    mirror = np.random.rand() >= 0.5
    return masked[:, ::-1, :] if mirror else masked
def generator(X, y, batch_size, use_augmentation, n_latent_dims):
    """Yield shuffled ``(X_batch, y_batch)`` training batches forever.

    Each pass over the data uses a fresh random permutation and emits full batches
    of ``batch_size``. Samples left over at the end of a pass are dropped.

    Args:
        X: images indexed on axis 0, expected to be already scaled by 1/255
            (no normalisation is applied here).
        y: one-hot labels, shape (n_samples, n_classes).
        batch_size: samples per yielded batch; must be in 1..len(X).
        use_augmentation: if True, each image goes through ``data_augmentation``.
        n_latent_dims: width of each target row. Column 0 holds the integer class
            id and the remaining columns are zero dummies. Presumably the width is
            chosen to match the model output for Keras — TODO confirm.

    Yields:
        ``X_batch`` (float32, ``(batch_size,) + X.shape[1:]``) and
        ``y_batch`` (float32, ``(batch_size, n_latent_dims)``).

    Raises:
        ValueError: on the first ``next()`` if ``batch_size`` can never be filled.
    """
    if not 1 <= batch_size <= X.shape[0]:
        # Otherwise no batch is ever completed and the generator spins forever.
        raise ValueError(f"batch_size must be in [1, {X.shape[0]}], got {batch_size}")
    while True:
        X_cache, y_cache = [], []
        for i in np.random.permutation(X.shape[0]):
            image = data_augmentation(X[i]) if use_augmentation else X[i]
            # Column 0 carries the integer class id; the other columns are dummies.
            y_item = np.zeros(n_latent_dims)
            y_item[0] = np.sum(np.arange(y.shape[1]) * y[i])
            X_cache.append(image)
            y_cache.append(y_item)
            if len(y_cache) == batch_size:
                X_batch = np.asarray(X_cache, np.float32)
                y_batch = np.asarray(y_cache, np.float32)
                X_cache, y_cache = [], []
                yield X_batch, y_batch
def step_decay(epoch):
    """Return the learning rate for ``epoch`` (0-based) under a step schedule.

    Epochs 0-11 use 1e-3, epochs 12-18 use 2e-4, and epoch 19 onward uses 4e-5.
    """
    if epoch >= 19:
        return 4e-5
    if epoch >= 12:
        return 2e-4
    return 1e-3
def train(n_dims, use_aug):
    """Train the embedding network with the triplet loss on a Colab TPU.

    Steps:
      * load data with ``train1000.cifar10``;
      * build and compile the model;
      * convert it to a TPU model using the address in the ``COLAB_TPU_ADDR``
        environment variable;
      * fit for 25 epochs with the step-decay schedule and per-epoch embedding
        evaluation.

    Args:
        n_dims: embedding size, passed to ``create_siamese`` and the batch generator.
        use_aug: whether the batch generator applies ``data_augmentation``.

    Raises:
        KeyError: if ``COLAB_TPU_ADDR`` is not set.
    """
    (X_train, y_train), (X_test, y_test) = cifar10()

    model = create_siamese(n_dims)
    model.compile("adam", euclidean_triplet_loss)

    tpu_address = "grpc://" + os.environ["COLAB_TPU_ADDR"]
    resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu_address)
    model = tf.contrib.tpu.keras_to_tpu_model(
        model, strategy=keras_support.TPUDistributionStrategy(resolver)
    )

    batch_size = 200
    # Roughly 400 passes over the training set per Keras "epoch".
    steps_per_epoch = X_train.shape[0] * 400 // batch_size
    callbacks = [
        EmbeddingCallback(model, X_train, y_train, X_test, y_test),
        History(),
        LearningRateScheduler(step_decay),
    ]
    model.fit_generator(
        generator(X_train, y_train, batch_size, use_aug, n_dims),
        steps_per_epoch=steps_per_epoch,
        callbacks=callbacks,
        max_queue_size=1,
        epochs=25,
    )
if __name__ == "__main__":
    # Single run: 512-dimensional embeddings, no data augmentation.
    dims = 512
    print(dims, "starts")
    train(dims, use_aug=False)
