Skip to content

Instantly share code, notes, and snippets.

@RITIK-12
Last active May 13, 2021 18:21
Show Gist options
  • Save RITIK-12/2d717042224afaa9e6c34ab8e076e60b to your computer and use it in GitHub Desktop.
Save RITIK-12/2d717042224afaa9e6c34ab8e076e60b to your computer and use it in GitHub Desktop.
import fnmatch
import os
from matplotlib import pyplot as plt
import cv2
from facenet_pytorch import MTCNN, InceptionResnetV1
resnet = InceptionResnetV1(pretrained='vggface2').eval()
# Load the cascade
face_cascade = cv2.CascadeClassifier('/haarcascade_frontalface_default.xml')
def face_match(img_path, data_path): # img_path= location of photo, data_path= location of data.pt
# getting embedding matrix of the given img
img = cv2.imread(img_path)
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# Detect faces
faces = face_cascade.detectMultiScale(gray, 1.1, 4)
# Draw rectangle around the faces
for (x, y, w, h) in faces:
crop_face = img[y:y+h, x:x+w]
img = cv2.imwrite(img_path,crop_face)
emb = resnet(img.unsqueeze(0)).detach() # detech is to make required gradient false
saved_data = torch.load('model.pt') # loading data.pt file
embedding_list = saved_data[0] # getting embedding data
name_list = saved_data[1] # getting list of names
dist_list = [] # list of matched distances, minimum distance is used to identify the person
for idx, emb_db in enumerate(embedding_list):
dist = torch.dist(emb, emb_db).item()
dist_list.append(dist)
idx_min = dist_list.index(min(dist_list))
return (name_list[idx_min], min(dist_list))
result = face_match('trainset/0006/0006_0000546/0006_0000546_script.jpg', '/model.pt')
print('Face matched with: ',result[0], 'With distance: ',result[1])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment