Skip to content

Instantly share code, notes, and snippets.

@rotemtzaban
Created February 21, 2022 13:55
Show Gist options
  • Save rotemtzaban/d2f0a72e790a60d5390553048809e3d5 to your computer and use it in GitHub Desktop.
Save rotemtzaban/d2f0a72e790a60d5390553048809e3d5 to your computer and use it in GitHub Desktop.
TL-ID and TG-ID metrics
import os
import cv2
import numpy as np
from insightface.app import FaceAnalysis
from tqdm import tqdm
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff',
]


def is_image_file(filename):
    """Return True when *filename* ends with one of ``IMG_EXTENSIONS``.

    The comparison ignores case, so names such as ``frame.TIFF`` or
    ``frame.Jpg`` are accepted even though the list does not spell out every
    upper/lower-case variant.

    Args:
        filename: File name or path to test.

    Returns:
        bool: Whether the name carries a recognised image extension.
    """
    # Lower-case both sides; str.endswith accepts a tuple of candidates.
    suffixes = tuple(extension.lower() for extension in IMG_EXTENSIONS)
    return filename.lower().endswith(suffixes)
def make_dataset(directory):
    """List the image files located directly inside *directory*.

    Directory entries are visited in sorted name order; an entry is kept when
    ``is_image_file`` accepts its name. Sub-directories are not descended into.

    Args:
        directory: Path of the folder to scan.

    Returns:
        list: ``os.path.join(directory, name)`` for every accepted entry.

    Raises:
        AssertionError: If *directory* is not an existing directory.
    """
    assert os.path.isdir(directory), '%s is not a valid directory' % directory
    return [
        os.path.join(directory, entry)
        for entry in sorted(os.listdir(directory))
        if is_image_file(entry)
    ]
def get_face_embedding(app, image):
    """Return the unit-length embedding of the first face found in *image*.

    Args:
        app: Object whose ``get(image)`` returns a sequence of detections,
            each indexable by ``'embedding'`` (``measure_metrics`` passes an
            insightface ``FaceAnalysis`` instance).
        image: Image handed unchanged to ``app.get``.

    Returns:
        The first detection's embedding divided by its ``np.linalg.norm``,
        or ``None`` when ``app.get`` reports no detections.
    """
    faces = app.get(image)
    if len(faces) > 0:
        # Only the first reported detection is used; any others are ignored.
        vector = faces[0]['embedding']
        return vector / np.linalg.norm(vector)
    return None
def measure_local_similarity(source_similarity_matrix, edit_similarity_matrix):
    """Average the edit/source similarity ratio over consecutive index pairs.

    Args:
        source_similarity_matrix: Square array of pairwise similarities for
            the source embeddings.
        edit_similarity_matrix: Array of the same shape for the edited
            embeddings.

    Returns:
        Mean of the element-wise ratio ``edit / source`` taken along the first
        super-diagonal, i.e. over entries ``(i, i + 1)``.
    """
    ratio = edit_similarity_matrix / source_similarity_matrix
    # k=1 selects the diagonal just above the main one: neighbours (i, i + 1).
    adjacent_ratio = np.diag(ratio, k=1)
    return np.mean(adjacent_ratio)
def measure_global_similarity(source_similarity_matrix, edit_similarity_matrix):
    """Average the edit/source similarity ratio over all off-diagonal pairs.

    Args:
        source_similarity_matrix: Square array of pairwise similarities for
            the source embeddings.
        edit_similarity_matrix: Array of the same shape for the edited
            embeddings.

    Returns:
        Mean of the element-wise ratio ``edit / source`` over every entry
        ``(i, j)`` with ``i != j``.
    """
    ratio = edit_similarity_matrix / source_similarity_matrix
    size = source_similarity_matrix.shape[0]
    # Boolean mask that is False on the main diagonal and True elsewhere, so
    # each item's comparison with itself is excluded from the mean.
    not_self = ~np.eye(size, dtype=bool)
    return np.mean(ratio, where=not_self)
def measure_metrics(source_video_files, edited_video_files):
    """Compute the TL-ID and TG-ID metrics for a source/edited frame sequence.

    Both arguments are directories of frame images. Frames are listed with
    ``make_dataset`` (sorted by file name) and paired by position, so the two
    directories must contain the same number of images. A normalised face
    embedding is extracted from every frame, the pairwise dot-product
    similarity matrices of the source and edited embeddings are built, and the
    two matrices are compared by ``measure_local_similarity`` and
    ``measure_global_similarity``.

    With a single frame there are no frame pairs to compare, so both values
    come out as NaN.

    Args:
        source_video_files: Directory holding the source video frames.
        edited_video_files: Directory holding the edited video frames.

    Returns:
        dict: ``{'tl_id': <local similarity>, 'tg_id': <global similarity>}``.

    Raises:
        ValueError: If the directories contain no images or a different
            number of images.
        OSError: If a frame cannot be read by ``cv2.imread``.
        RuntimeError: If no face is detected in a source or edited frame.
    """
    source_ds = make_dataset(source_video_files)
    edited_ds = make_dataset(edited_video_files)

    # Validate before loading the face model: zip() would otherwise silently
    # drop the surplus frames and pair up misaligned sequences.
    if len(source_ds) != len(edited_ds):
        raise ValueError(
            f'Frame count mismatch: {len(source_ds)} source frames vs '
            f'{len(edited_ds)} edited frames.'
        )
    if not source_ds:
        raise ValueError(f'No image files found in {source_video_files}.')

    app = FaceAnalysis(providers=['CUDAExecutionProvider'], allowed_modules=['detection', 'recognition'])
    app.prepare(ctx_id=0, det_size=(640, 640))

    source_embeddings = []
    edited_embeddings = []
    frame_pairs = zip(source_ds, edited_ds)
    for i, (source_path, edit_path) in enumerate(tqdm(frame_pairs, total=len(source_ds), leave=False)):
        source_image = cv2.imread(source_path)
        edit_image = cv2.imread(edit_path)
        # cv2.imread reports an unreadable file by returning None, not raising.
        for path, image in ((source_path, source_image), (edit_path, edit_image)):
            if image is None:
                raise OSError(f'Failed reading image {path}.')

        source_embedding = get_face_embedding(app, source_image)
        edited_embedding = get_face_embedding(app, edit_image)
        if source_embedding is None or edited_embedding is None:
            raise RuntimeError(f'Failed detecting faces in frame {i} in video.')

        source_embeddings.append(source_embedding)
        edited_embeddings.append(edited_embedding)

    source_embeddings = np.stack(source_embeddings)
    edited_embeddings = np.stack(edited_embeddings)

    # Embeddings are unit length, so these dot products are the pairwise
    # similarities between frames within each sequence.
    source_similarity_matrix = source_embeddings @ source_embeddings.T
    edit_similarity_matrix = edited_embeddings @ edited_embeddings.T

    mean_local_similarity = measure_local_similarity(source_similarity_matrix, edit_similarity_matrix)
    mean_global_similarity = measure_global_similarity(source_similarity_matrix, edit_similarity_matrix)
    return {'tl_id': mean_local_similarity, 'tg_id': mean_global_similarity}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment