Created
April 12, 2019 19:09
-
-
Save gaurav-gogia/db523d524a425a7baac9099ea6a00f54 to your computer and use it in GitHub Desktop.
Trying FaceNet-based face verification.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import cv2 | |
import time | |
import glob | |
import facenet | |
import fr_utils | |
import numpy as np | |
import tensorflow as tf | |
from numpy import genfromtxt | |
from keras import backend as K | |
from keras.optimizers import Adam | |
from multiprocessing.dummy import Pool | |
# FaceNet's Keras weights expect NCHW ordering (channels before height/width).
K.set_image_data_format('channels_first')

PADDING = 50  # pixel margin used when cropping faces (not referenced in this chunk)
epsilon = 0.2  # distance threshold: matches farther than this are rejected in verify()
imgsize = 96  # FaceNet input resolution (96x96)
trainpath = "../ssd/gaurav_train/gaurav/*"  # glob of enrollment images, one per identity
testpath = "../ssd/gaurav_test/gaurav/"  # directory holding the probe images tested below
def cosine_similarity(src_rep, dst_rep):
    """Return the cosine *distance* (1 - cosine similarity) between two encodings.

    Despite the name, 0.0 means identical direction and larger values mean
    less similar; orthogonal vectors score 1.0.
    """
    dot = np.matmul(np.transpose(src_rep), dst_rep)
    src_norm = np.sqrt(np.sum(np.square(src_rep)))
    dst_norm = np.sqrt(np.sum(np.square(dst_rep)))
    return 1 - dot / (src_norm * dst_norm)
def triplet_loss(ytrue, ypred, alpha=0.3):
    """FaceNet triplet loss: sum over the batch of max(0, d(a,p) - d(a,n) + alpha).

    `ytrue` is unused — it exists only to satisfy the Keras loss signature.
    `ypred` stacks the (anchor, positive, negative) embeddings along axis 0;
    `alpha` is the margin separating positive from negative pairs.
    """
    anchor = ypred[0]
    positive = ypred[1]
    negative = ypred[2]
    # Squared L2 distances, reduced over the embedding dimension.
    pos_dist = tf.reduce_sum(tf.square(anchor - positive), axis=-1)
    neg_dist = tf.reduce_sum(tf.square(anchor - negative), axis=-1)
    # Hinge at zero, then sum over the batch.
    return tf.reduce_sum(tf.maximum(pos_dist - neg_dist + alpha, 0.0))
# Build the FaceNet embedding model (channels-first: 3 x 96 x 96 input),
# compile it with the triplet loss defined above, and load the pretrained
# weights shipped with fr_utils.
model = facenet.FaceNet(input_shape=(3, imgsize, imgsize))
model.compile(optimizer='adam', loss=triplet_loss, metrics=['accuracy'])
fr_utils.load_weights_from_FaceNet(model)
def prep_db():
    """Build the enrollment database: identity name -> FaceNet encoding.

    The identity is taken from each training image's filename (stem, no
    extension); the encoding is computed once per image via fr_utils.
    """
    return {
        os.path.splitext(os.path.basename(path))[0]:
            fr_utils.img_path_to_encoding(path, model)
        for path in glob.glob(trainpath)
    }
def verify(image, db, model):
    """Identify the person in `image` against the enrollment database.

    Computes the FaceNet encoding of `image` and finds the enrolled identity
    with the smallest L2 distance to it.

    Args:
        image: BGR image array (as returned by cv2.imread).
        db: mapping of identity name -> stored encoding (from prep_db).
        model: the compiled FaceNet model used to encode `image`.

    Returns:
        (identity, distance) for the closest match, or None when the closest
        match is farther than the module-level `epsilon` threshold.
    """
    encoding = fr_utils.img_to_encoding(image, model)
    # float('inf') instead of the old arbitrary cap of 100: any real distance
    # now always wins the first comparison, even for unnormalized encodings.
    min_dist = float('inf')
    identity = None
    for name, db_enc in db.items():
        dist = np.linalg.norm(db_enc - encoding)
        if dist < min_dist:
            print('Distance for %s is: %s' % (name, dist))
            min_dist = dist
            identity = name
    if min_dist > epsilon:
        return None
    # Return the winning distance itself. The old code returned np.prod over
    # every distance that improved the running minimum — a product of
    # unrelated distances with no metric meaning.
    return str(identity), min_dist
db = prep_db()

# Separator printed after same-identity probes vs. different-identity probes
# (preserved byte-for-byte from the original output).
DASH_SEP = '------------------------------------------------------------------------------------------------'
HASH_SEP = '##############################################################################################################'

# (probe filename, separator to print after it) — None means no separator.
# The first three are the enrolled person; the rest are impostor probes.
test_cases = [
    ("gaurav.jpg", DASH_SEP),
    ("gav3.jpg", DASH_SEP),
    ("gav4.jpg", HASH_SEP),
    ("roma.jpg", HASH_SEP),
    ("tar2.jpeg", HASH_SEP),
    ("tarang.jpg", None),
]

for fname, separator in test_cases:
    path = testpath + fname
    print('Testing for: ', path)
    img = cv2.imread(path)
    # cv2.imread returns None (no exception) for a missing/unreadable file;
    # without this guard verify() would crash on the encoding step.
    if img is None:
        print('Could not read image: %s' % path)
    else:
        print(verify(img, db, model))
    if separator is not None:
        print()
        print(separator)
        print()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment