Last active
November 24, 2021 02:44
-
-
Save guipotje/7d67e5de338439579dc78449e78e188c to your computer and use it in GitHub Desktop.
Using TPS warp files to generate ground-truth keypoint correspondences
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
Generate ground-truth keypoint correspondences from optimized TPS warp files.

Before running, make sure that:
1 - The dependencies are installed:
    pip install opencv-python==4.5.3.56 torch tqdm scipy
2 - The py-thin-plate-spline implementation is available. Clone it from its original repository:
    git clone https://github.com/cheind/py-thin-plate-spline
    Note: the script expects this repository to be cloned into the same folder as this script.
'''
import os | |
import sys | |
# Locate the py-thin-plate-spline checkout and make it importable.
# When run as a file, look next to this script; otherwise (e.g. an interactive
# session, where __file__ is undefined) look in the current working directory.
# At module level vars() is globals(), so a single membership test suffices.
if '__file__' in globals():
    tps_repo_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'py-thin-plate-spline')
else:
    tps_repo_path = os.path.join('.', 'py-thin-plate-spline')
if not os.path.isdir(tps_repo_path):
    raise RuntimeError(
        'TPS repository is required. Please clone '
        'https://github.com/cheind/py-thin-plate-spline into ' + tps_repo_path)
sys.path.insert(0, tps_repo_path)
import thinplate as tps | |
import torch | |
import cv2 | |
import glob | |
import tqdm | |
import argparse | |
import numpy as np | |
from scipy.spatial import KDTree | |
def write_sift(filepath, kps):
    """Write keypoints to ``<filepath>.sift`` as a comma-separated text table.

    The first line is the header ``size, angle, x, y, octave``; each following
    line holds one keypoint. Every keypoint must expose ``size``, ``angle``,
    ``pt`` (an (x, y) pair) and ``octave`` attributes, as cv2.KeyPoint does.
    """
    row_fmt = '%.2f, %.3f, %.2f, %.2f, %d\n'
    rows = (row_fmt % (kp.size, kp.angle, kp.pt[0], kp.pt[1], kp.octave) for kp in kps)
    with open(filepath + '.sift', 'w') as out:
        out.write('size, angle, x, y, octave\n')
        out.writelines(rows)
def write_matches(filepath, idx_ref, idx_tgt):
    """Write index correspondences to ``<filepath>.match``.

    The first line is the header ``idx_ref,idx_tgt``; line i+1 is
    ``<idx_ref[i]>, <idx_tgt[i]>``.

    Args:
        filepath: output path without extension.
        idx_ref: sequence of integer indices into the reference keypoints.
        idx_tgt: sequence of integer indices into the target keypoints,
            paired element-wise with idx_ref.

    Raises:
        ValueError: if idx_ref and idx_tgt differ in length. The check runs
            before the file is opened, so a mismatch never leaves a partially
            written .match file behind.
    """
    if len(idx_ref) != len(idx_tgt):
        raise ValueError('idx_ref and idx_tgt must have the same length (%d != %d)'
                         % (len(idx_ref), len(idx_tgt)))
    with open(filepath + '.match', 'w') as f:
        f.write('idx_ref,idx_tgt\n')
        f.writelines('%d, %d\n' % (ref, tgt) for ref, tgt in zip(idx_ref, idx_tgt))
def draw_cv_matches(src_img, tgt_img, src_kps, tgt_kps, gt_ref, gt_tgt):
    """Render ground-truth correspondences side by side with cv2.drawMatches.

    Args:
        src_img: reference image.
        tgt_img: target image.
        src_kps: keypoints detected on src_img.
        tgt_kps: keypoints detected on tgt_img.
        gt_ref, gt_tgt: equal-length index sequences; pair i links
            src_kps[gt_ref[i]] to tgt_kps[gt_tgt[i]].

    Returns:
        The composite visualization returned by cv2.drawMatches.

    Raises:
        ValueError: if gt_ref and gt_tgt differ in length.
    """
    if len(gt_ref) != len(gt_tgt):
        raise ValueError('gt_ref and gt_tgt must have the same length (%d != %d)'
                         % (len(gt_ref), len(gt_tgt)))
    # The indices usually arrive as NumPy integer scalars. Cast them to plain
    # Python ints defensively; OpenCV's binding conversion for int parameters
    # may reject NumPy scalars depending on the version.
    dmatches = [cv2.DMatch(int(ref), int(tgt), 0.) for ref, tgt in zip(gt_ref, gt_tgt)]
    return cv2.drawMatches(src_img, src_kps, tgt_img, tgt_kps, dmatches, None, flags=0)
def parseArg(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: optional list of argument strings. When None, argparse reads
            sys.argv[1:].

    Returns:
        argparse.Namespace with attributes ``input``, ``tps_dir``, ``dir``,
        ``method``, ``px_thresh`` and ``vis_dir``.
    """
    parser = argparse.ArgumentParser(
        description='Generate ground-truth keypoint correspondences from TPS warp files.')
    parser.add_argument("-i", "--input", required=True,
                        help="Input directory for one or more datasets (use --dir for several)")
    parser.add_argument("--tps_dir", required=True,
                        help="Input directory containing the optimized TPS params for one or more datasets (use --dir for several)")
    parser.add_argument("-d", "--dir", action='store_true',
                        help="is a dir with several dataset folders?")
    # Read by the processing loop to decide how target images are loaded:
    # 'sift' loads them as grayscale; any other value loads them as RGB.
    parser.add_argument("--method", default='sift',
                        help="Feature method; 'sift' reads target images in grayscale, others in RGB (default: sift)")
    parser.add_argument("--px_thresh", type=float, default=1.0,
                        help="Max distance in pixels between a warped target keypoint and its reference keypoint (default: 1.0)")
    parser.add_argument("--vis_dir", default=None,
                        help="Optional directory where match visualizations are written (disabled when omitted)")
    return parser.parse_args(argv)
# Resolve the dataset folders to process: the input folder itself, or (with
# --dir) every directory exactly two levels below it. Only folders whose path
# contains one of the supported dataset names are kept.
args = parseArg()
args.input = os.path.abspath(args.input)
datasets = (
    [path for path in glob.glob(args.input + '/*/*') if os.path.isdir(path)]
    if args.dir
    else [args.input]
)
tps_path = args.tps_dir
_DATASET_KEYWORDS = ('DeSurTSampled', 'Kinect1', 'Kinect2Sampled', 'SimulationICCV')
datasets = [path for path in datasets if any(key in path for key in _DATASET_KEYWORDS)]

# Shared by every dataset: the SIFT detector and the torch device for the TPS evaluation.
SIFT = cv2.SIFT_create(nfeatures=2048, contrastThreshold=0.004)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Ground-truth generation.
# For every dataset, SIFT keypoints are detected once on the master (reference)
# frame. Then, for every target frame, the target keypoints are warped into the
# reference frame with the stored TPS parameters and paired with the nearest
# reference keypoint closer than px_thresh pixels.
#
# Optional settings are read with getattr() so this loop also works with a
# parser that does not define them:
#   args.method    -- 'sift' (default) loads targets in grayscale, else RGB.
#   args.px_thresh -- max pixel distance for a ground-truth match (1.0).
#   args.vis_dir   -- if set, a match visualization PNG is written there.
method = getattr(args, 'method', 'sift')
px_thresh = float(getattr(args, 'px_thresh', 1.0))
vis_dir = getattr(args, 'vis_dir', None)
if vis_dir:
    os.makedirs(vis_dir, exist_ok=True)


def _imread_checked(path, flags=cv2.IMREAD_COLOR):
    """Read an image with cv2.imread, raising instead of returning None."""
    img = cv2.imread(path, flags)
    if img is None:
        raise RuntimeError('Could not read image ' + path)
    return img


def _filter_by_mask(kps, mask):
    """Keep keypoints whose truncated (x, y) location is non-zero in mask."""
    return [kp for kp in kps if mask[int(kp.pt[1]), int(kp.pt[0])]]


def _unique_nearest(gt_ref, gt_tgt, dists):
    """Make matches one-to-one: for each reference index keep the closest target.

    Candidates are visited in order of increasing distance, so np.unique's
    first occurrence per reference index is the nearest one. The output is
    sorted by reference index.
    """
    order = np.argsort(dists, kind='stable')
    _, first = np.unique(gt_ref[order], return_index=True)
    keep = order[first]
    return gt_ref[keep], gt_tgt[keep]


for dataset in datasets:
    if not glob.glob(dataset + '/*.csv'):
        raise RuntimeError('Empty dataset with no .csv file')
    targets = [os.path.splitext(t)[0] for t in glob.glob(dataset + '/*[0-9].csv')]
    masters = glob.glob(dataset + '/*master.csv')
    if not masters:
        raise RuntimeError('No *master.csv file in ' + dataset)
    master = os.path.splitext(masters[0])[0]
    ref_img = _imread_checked(master + '-rgb.png', 0)

    # TPS results live under <tps_dir>/<parent folder>/<dataset folder>.
    loading_path = os.path.join(tps_path, *os.path.normpath(dataset).split(os.sep)[-2:])
    if not os.path.exists(loading_path):
        raise RuntimeError('There is no TPS directory in ' + loading_path)
    ref_mask = _imread_checked(loading_path + '/' + os.path.basename(master) + '_objmask.png', 0)

    ref_kps = SIFT.detect(ref_img, None)
    print('Detected ref kps: ', len(ref_kps))
    ref_kps = _filter_by_mask(ref_kps, ref_mask)  # filter by object mask
    write_sift(loading_path + '/' + os.path.basename(master), ref_kps)
    # Reference keypoints are identical for every target: build the KD-tree
    # once per dataset. scipy's KDTree cannot be built from an empty set.
    ref_tree = KDTree([kp.pt for kp in ref_kps]) if ref_kps else None

    for target in tqdm.tqdm(targets, desc='image pairs'):
        loading_file = loading_path + '/' + os.path.basename(target)
        theta_np = np.load(loading_file + '_theta.npy').astype(np.float32)
        ctrl_pts = np.load(loading_file + '_ctrlpts.npy').astype(np.float32)
        score = _imread_checked(loading_file + '_SSIM.png', 0) / 255.0
        tgt_mask = _imread_checked(loading_file + '_objmask.png', 0)
        if method != 'sift':
            target_img = cv2.cvtColor(_imread_checked(target + '-rgb.png'), cv2.COLOR_BGR2RGB)
        else:
            target_img = _imread_checked(target + '-rgb.png', 0)
        score_mask = score > 0.25  # drops regions with very low SSIM score

        target_kps = SIFT.detect(target_img, None)
        print('Detected target kps: ', len(target_kps))
        target_kps = _filter_by_mask(target_kps, tgt_mask)  # filter by object mask
        target_kps = _filter_by_mask(target_kps, score_mask)  # filter by score map

        gt_ref = np.empty(0, dtype=np.intp)
        gt_tgt = np.empty(0, dtype=np.intp)
        if target_kps and ref_tree is not None:
            # Coordinates are divided by (width, height) before the TPS
            # evaluation and scaled back to pixels afterwards.
            norm_factor = np.array(target_img.shape[:2][::-1], dtype=np.float32)
            theta = torch.tensor(theta_np, device=device)
            tgt_coords = np.array([kp.pt for kp in target_kps], dtype=np.float32)
            warped_coords = tps.torch.tps_sparse(
                theta,
                torch.tensor(ctrl_pts, device=device),
                torch.tensor(tgt_coords / norm_factor, device=device),
            ).squeeze(0).cpu().numpy() * norm_factor

            dists, idxs_ref = ref_tree.query(warped_coords)
            close = dists < px_thresh  # ground-truth threshold, in pixels
            gt_ref, gt_tgt = _unique_nearest(idxs_ref[close], np.flatnonzero(close), dists[close])

        if vis_dir:
            img_match = draw_cv_matches(ref_img, target_img, ref_kps, target_kps, gt_ref, gt_tgt)
            cv2.imwrite(os.path.join(vis_dir, os.path.basename(target) + '_match.png'), img_match)
        write_sift(loading_file, target_kps)
        write_matches(loading_file, gt_ref, gt_tgt)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment