# Created
# February 10, 2020 10:02
# -
# -
# Save fede-vaccaro/3f27063c189d84542c0e159ac4cd795f to your computer and use it in GitHub Desktop.
# This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
import itertools
import os
import random
import shutil

import cv2
import matplotlib.pyplot as plt
import networkx
import numpy as np
from networkx import (
    draw,
    Graph,
)
from tqdm import tqdm

import paths
# NOTE(review): this assignment runs after the imports at the top of the
# file, so libraries that read OMP_NUM_THREADS at import time may not honor
# it -- confirm, and export the variable in the shell if they do not.
os.environ["OMP_NUM_THREADS"] = "12"

# Input: one sub-directory per landmark label, each holding that label's images.
source_dir = paths.landmarks_path

# Output root: landmark/landmark_cleaned
dest_dir_name = "landmark_cleaned"
base_dir = 'landmark/'
full_dest_dir_name = os.path.join(base_dir, dest_dir_name)

# makedirs (rather than mkdir) also creates the missing 'landmark/' parent,
# and exist_ok makes re-runs a no-op instead of an error.
os.makedirs(full_dest_dir_name, exist_ok=True)

# Keep only directories: a stray file in source_dir would otherwise make the
# per-label os.listdir() call fail later on.
all_labels = [
    label_path
    for label_path in (
        os.path.join(source_dir, entry) for entry in os.listdir(source_dir)
    )
    if os.path.isdir(label_path)
]

# Process the labels in random order.
# NOTE(review): presumably so that several concurrent runs start on different
# labels -- confirm with the author.
random.shuffle(all_labels)
RATIO_THR = 0.7  # Lower values mean more aggressive filtering.


def ratio_test_(m):
    """Return True when the best match in *m* passes the ratio test.

    Args:
        m: Sequence of matches for one query descriptor, best first, as
            produced by ``knnMatch(..., k=2)``. Each item must expose a
            ``distance`` attribute.

    Returns:
        True if the best distance is smaller than ``RATIO_THR`` times the
        second-best distance. False otherwise, including when fewer than two
        neighbours are available (knnMatch returns shorter lists when the
        train set holds fewer than ``k`` descriptors), because the ratio is
        undefined in that case.
    """
    if len(m) < 2:
        return False
    best, second_best = m[0], m[1]
    return best.distance < second_best.distance * RATIO_THR
def _create_sift():
    """Return a SIFT detector from wherever this OpenCV build exposes it."""
    # SIFT is part of the main cv2 module since OpenCV 4.4; older builds only
    # ship it in the contrib module cv2.xfeatures2d.
    if hasattr(cv2, "SIFT_create"):
        return cv2.SIFT_create()
    return cv2.xfeatures2d.SIFT_create()


def _count_homography_inliers(kps0, fea0, img0_shape, img1, detector, matcher):
    """Count the RANSAC homography inliers between two images.

    Args:
        kps0: Keypoints of the first image.
        fea0: Descriptors of the first image (may be None).
        img0_shape: ``shape`` of the first image; the second image is resized
            to this height/width before its features are extracted.
        img1: The second image as loaded by ``cv2.imread``.
        detector: Feature detector exposing ``detectAndCompute``.
        matcher: Descriptor matcher exposing ``knnMatch``.

    Returns:
        The number of mutually consistent, ratio-tested matches that RANSAC
        keeps as inliers; 0 when there is not enough data to estimate a
        homography.
    """
    height, width = img0_shape[:2]
    resized1 = cv2.resize(img1, (width, height))
    kps1, fea1 = detector.detectAndCompute(resized1, None)

    # The ratio test needs two neighbours per query, hence at least two
    # descriptors on each side.
    if fea0 is None or fea1 is None or len(fea0) < 2 or len(fea1) < 2:
        return 0

    forward = [
        pair[0]
        for pair in matcher.knnMatch(fea0, fea1, k=2)
        if ratio_test_(pair)
    ]
    # Index the backward matches as (index in image 0, index in image 1) so
    # they can be compared directly with the forward ones.
    backward = {
        (pair[0].trainIdx, pair[0].queryIdx)
        for pair in matcher.knnMatch(fea1, fea0, k=2)
        if ratio_test_(pair)
    }
    mutual = [m for m in forward if (m.queryIdx, m.trainIdx) in backward]

    # A homography has 8 degrees of freedom: fewer than 4 correspondences
    # cannot define one.
    if len(mutual) < 4:
        return 0

    pts0 = np.float32([kps0[m.queryIdx].pt for m in mutual]).reshape(-1, 2)
    pts1 = np.float32([kps1[m.trainIdx].pt for m in mutual]).reshape(-1, 2)
    _, inlier_mask = cv2.findHomography(pts0, pts1, cv2.RANSAC, 3.0)
    if inlier_mask is None:
        return 0
    return int(np.count_nonzero(inlier_mask))


# Built once and reused for every pair instead of once per pair.
_sift_detector = _create_sift()
_bf_matcher = cv2.BFMatcher_create(cv2.NORM_L2, False)

# For every label directory: match all image pairs, keep the pairs with enough
# geometrically verified matches, and copy the images of each connected
# component of the resulting graph into its own output sub-directory.
for label in all_labels:
    pics = [os.path.join(label, name) for name in os.listdir(label)]

    new_dir_name = os.path.join(
        full_dest_dir_name,
        "cleaned_{}".format(os.path.split(label)[-1]),
    )
    # Creating the directory doubles as the "already processed" marker; doing
    # it with a single mkdir call avoids the exists()/mkdir() race.
    try:
        os.mkdir(new_dir_name)
    except FileExistsError:
        print("Skipping: {}, already existent.".format(label))
        continue

    # Every unordered pair of images, in the same order as a nested i < j loop.
    candidate_pairs = list(itertools.combinations(pics, 2))
    pw_matches = len(candidate_pairs)
    if pw_matches == 0:
        # log(0) below would be -inf and int() would raise OverflowError.
        print("Skipping: {}. Fewer than two images.".format(label))
        continue
    if pw_matches > 2000:
        print("Skipping: {}. Too much images.".format(label))
        continue

    # Minimum inlier count (exclusive) for a pair to be kept; it grows with
    # the logarithm of the number of pairs.
    th = int(np.round(np.log(pw_matches)))

    # cv2.imread returns None for files it cannot decode.
    im_dict = {}
    for pic in pics:
        image = cv2.imread(pic, cv2.IMREAD_COLOR)
        if image is None:
            print("Could not read image: {}".format(pic))
        im_dict[pic] = image

    features_cache = {}  # path -> (keypoints, descriptors) at native size
    pairwise_matches = {}  # (path0, path1) -> inlier count, kept pairs only
    for im0_path, im1_path in tqdm(candidate_pairs):
        img0 = im_dict[im0_path]
        img1 = im_dict[im1_path]
        if img0 is None or img1 is None:
            continue
        try:
            # The first image is never resized, so its features are the same
            # for every pair it leads and can be computed once.
            if im0_path not in features_cache:
                features_cache[im0_path] = _sift_detector.detectAndCompute(
                    img0, None
                )
            kps0, fea0 = features_cache[im0_path]
            n_inliers = _count_homography_inliers(
                kps0, fea0, img0.shape, img1, _sift_detector, _bf_matcher
            )
        except cv2.error as err:
            # Best effort: report the failing pair and move on; the pair is
            # simply not added to the graph.
            tqdm.write(
                "Matching failed for {} / {}: {}".format(im0_path, im1_path, err)
            )
            continue
        if n_inliers > th:
            pairwise_matches[(im0_path, im1_path)] = n_inliers

    # Images that share no surviving pair never enter the graph and are
    # therefore not copied.
    undirected = Graph()
    undirected.add_edges_from(pairwise_matches)

    for i, component in enumerate(networkx.connected_components(undirected)):
        new_component_dir_name = os.path.join(
            new_dir_name, "component_{}".format(i)
        )
        os.makedirs(new_component_dir_name, exist_ok=True)
        for pic in component:
            shutil.copy(
                pic,
                os.path.join(new_component_dir_name, os.path.split(pic)[-1]),
            )
# Sign up for free
# to join this conversation on GitHub.
# Already have an account?
# Sign in to comment