Created
February 21, 2022 13:55
-
-
Save rotemtzaban/d2f0a72e790a60d5390553048809e3d5 to your computer and use it in GitHub Desktop.
TL-ID and TG-ID metrics
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import cv2 | |
import numpy as np | |
from insightface.app import FaceAnalysis | |
from tqdm import tqdm | |
# File-name suffixes treated as images by `is_image_file`.
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
]


def is_image_file(filename):
    """Return True when *filename* ends with one of ``IMG_EXTENSIONS``.

    The comparison is case-insensitive. The list spells out upper-case
    variants for most suffixes but not for ``.tiff``, so a case-sensitive
    match would silently skip files such as ``frame.TIFF`` or ``frame.Jpg``.

    Args:
        filename: File name or path to test.

    Returns:
        bool: Whether the name carries a recognised image suffix.
    """
    lowered = filename.lower()
    return any(lowered.endswith(extension.lower()) for extension in IMG_EXTENSIONS)
def make_dataset(directory):
    """Collect the image files directly inside *directory*.

    Entries are visited in sorted name order and kept when
    ``is_image_file`` accepts them; sub-directories are not descended into.

    Args:
        directory: Path of the folder to scan.

    Returns:
        list: ``os.path.join(directory, name)`` for every accepted entry,
        in sorted name order.

    Raises:
        AssertionError: If *directory* is not an existing directory.
    """
    assert os.path.isdir(directory), '%s is not a valid directory' % directory
    return [
        os.path.join(directory, entry)
        for entry in sorted(os.listdir(directory))
        if is_image_file(entry)
    ]
def get_face_embedding(app, image):
    """Return the unit-length embedding of the first face found in *image*.

    Args:
        app: Object whose ``get(image)`` returns a sequence of detections,
            each exposing its embedding under the ``'embedding'`` key.
        image: Image handed to ``app.get`` unchanged.

    Returns:
        The first detection's embedding divided by its Euclidean norm, or
        ``None`` when ``app.get`` reports no detections.
    """
    faces = app.get(image)
    if len(faces) > 0:
        # Only the first reported detection is used; any others are ignored.
        raw_embedding = faces[0]['embedding']
        return raw_embedding / np.linalg.norm(raw_embedding)
    return None
def measure_local_similarity(source_similarity_matrix, edit_similarity_matrix):
    """Mean edit/source similarity ratio over the first super-diagonal.

    Only the entries at ``(i, i + 1)`` of each matrix contribute, i.e. the
    similarity between each item and the one that follows it.

    Args:
        source_similarity_matrix: Square 2-D array of pairwise similarities
            for the source sequence.
        edit_similarity_matrix: Array of the same shape for the edited
            sequence.

    Returns:
        The mean of ``edit[i, i + 1] / source[i, i + 1]``, or NaN when the
        matrices have no super-diagonal (fewer than two rows).
    """
    # Extract the super-diagonals before dividing: the remaining entries are
    # never used, and dividing them could emit divide-by-zero warnings for
    # pairs that do not affect the result.
    source_adjacent = np.diag(source_similarity_matrix, k=1)
    edit_adjacent = np.diag(edit_similarity_matrix, k=1)
    if source_adjacent.size == 0:
        # Same value the mean of an empty array would give, minus the warning.
        return float('nan')
    return (edit_adjacent / source_adjacent).mean()
def measure_global_similarity(source_similarity_matrix, edit_similarity_matrix):
    """Mean edit/source similarity ratio over all off-diagonal entries.

    Args:
        source_similarity_matrix: Square 2-D array of pairwise similarities
            for the source sequence.
        edit_similarity_matrix: Array of the same shape for the edited
            sequence.

    Returns:
        The mean of the element-wise ratio ``edit / source`` taken over
        every position except the main diagonal.
    """
    size = source_similarity_matrix.shape[0]
    # True everywhere except on the main diagonal.
    not_on_diagonal = ~np.eye(size, dtype=bool)
    ratio = edit_similarity_matrix / source_similarity_matrix
    return np.mean(ratio, where=not_on_diagonal)
def measure_metrics(source_video_files, edited_video_files):
    """Compute the TL-ID and TG-ID scores for a source/edited frame pair.

    Both arguments are directories of frame images (they are passed to
    ``make_dataset``). Frames are paired by sorted file-name order. A face
    embedding is extracted from every frame, and the pairwise similarity
    matrices of the two sequences are compared locally (adjacent frames) and
    globally (all frame pairs).

    Args:
        source_video_files: Directory holding the source video's frames.
        edited_video_files: Directory holding the edited video's frames.

    Returns:
        dict: ``{'tl_id': <local score>, 'tg_id': <global score>}``.

    Raises:
        ValueError: If the directories contain no frames or a different
            number of frames.
        OSError: If a frame image cannot be read.
        RuntimeError: If no face is detected in some source or edited frame.
    """
    app = FaceAnalysis(providers=['CUDAExecutionProvider'], allowed_modules=['detection', 'recognition'])
    app.prepare(ctx_id=0, det_size=(640, 640))

    source_ds = make_dataset(source_video_files)
    edited_ds = make_dataset(edited_video_files)

    # zip() would silently drop the surplus frames of the longer sequence and
    # produce a score for mismatched videos, so refuse up front.
    if len(source_ds) != len(edited_ds):
        raise ValueError(
            f'Frame count mismatch: {len(source_ds)} source frames in '
            f'{source_video_files} vs {len(edited_ds)} edited frames in '
            f'{edited_video_files}.'
        )
    if not source_ds:
        raise ValueError(f'No image files found in {source_video_files}.')

    source_embeddings = []
    edited_embeddings = []
    frame_pairs = tqdm(zip(source_ds, edited_ds), total=len(source_ds), leave=False)
    for i, (source_path, edit_path) in enumerate(frame_pairs):
        source_image = cv2.imread(source_path)
        edit_image = cv2.imread(edit_path)
        # cv2.imread signals failure by returning None instead of raising.
        for path, image in ((source_path, source_image), (edit_path, edit_image)):
            if image is None:
                raise OSError(f'Failed reading image {path}.')

        source_embedding = get_face_embedding(app, source_image)
        edited_embedding = get_face_embedding(app, edit_image)
        if source_embedding is None or edited_embedding is None:
            raise RuntimeError(f'Failed detecting faces in frame {i} in video.')

        source_embeddings.append(source_embedding)
        edited_embeddings.append(edited_embedding)

    source_embeddings = np.stack(source_embeddings)
    edited_embeddings = np.stack(edited_embeddings)

    # Embeddings are unit-normalised by get_face_embedding, so these products
    # are the pairwise cosine similarities between frames.
    source_similarity_matrix = source_embeddings @ source_embeddings.T
    edit_similarity_matrix = edited_embeddings @ edited_embeddings.T

    mean_local_similarity = measure_local_similarity(source_similarity_matrix, edit_similarity_matrix)
    mean_global_similarity = measure_global_similarity(source_similarity_matrix, edit_similarity_matrix)
    return {'tl_id': mean_local_similarity, 'tg_id': mean_global_similarity}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment