Skip to content

Instantly share code, notes, and snippets.

@guipotje
Last active November 24, 2021 02:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save guipotje/7d67e5de338439579dc78449e78e188c to your computer and use it in GitHub Desktop.
Save guipotje/7d67e5de338439579dc78449e78e188c to your computer and use it in GitHub Desktop.
Using TPS warp files to generate ground-truth keypoint correspondences
'''
Please make sure that:
1 - You have the dependencies installed:
pip install opencv-python==4.5.3.56 torch tqdm scipy
2 - You have the py-thin-plate-spline implementation; you can obtain it from its original repository:
git clone https://github.com/cheind/py-thin-plate-spline
PS: We assume that you clone this repo into the same folder where this script is placed.
'''
import os
import sys
# Locate the py-thin-plate-spline checkout and put it at the front of sys.path
# so that `import thinplate` below resolves to it.
# When __file__ is defined the repo is expected next to this script; otherwise
# (e.g. the code is pasted into an interactive session) it is looked up
# relative to the current working directory.
if '__file__' in globals():
    tps_repo_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), 'py-thin-plate-spline'
    )
    if not os.path.exists(tps_repo_path):
        raise RuntimeError('TPS repository is required.. Please place the TPS repo in the same folder of this script')
else:
    tps_repo_path = './py-thin-plate-spline'
    if not os.path.exists(tps_repo_path):
        raise RuntimeError('TPS repository is required')
sys.path.insert(0, tps_repo_path)
import thinplate as tps
import torch
import cv2
import glob
import tqdm
import argparse
import numpy as np
from scipy.spatial import KDTree
def write_sift(filepath, kps):
    """Write keypoints to ``<filepath>.sift`` as comma-separated text.

    The file starts with the header line ``size, angle, x, y, octave`` and is
    followed by one line per keypoint. An existing file is overwritten.

    Args:
        filepath: Output path without extension; ``.sift`` is appended.
        kps: Iterable of keypoint-like objects exposing ``size``, ``angle``,
            ``pt`` (indexable, x at 0 and y at 1) and ``octave``.
    """
    with open(filepath + '.sift', 'w') as f:
        f.write('size, angle, x, y, octave\n')
        f.writelines(
            '%.2f, %.3f, %.2f, %.2f, %d\n' % (kp.size, kp.angle, kp.pt[0], kp.pt[1], kp.octave)
            for kp in kps
        )
def write_matches(filepath, idx_ref, idx_tgt):
    """Write paired keypoint indices to ``<filepath>.match``.

    The file starts with the header line ``idx_ref,idx_tgt`` and is followed by
    one ``ref, tgt`` line per entry of ``idx_ref``. An existing file is
    overwritten.

    Args:
        filepath: Output path without extension; ``.match`` is appended.
        idx_ref: Sequence of integer indices (reference side).
        idx_tgt: Sequence of integer indices (target side), indexed in
            lockstep with ``idx_ref``.

    Raises:
        IndexError: If ``idx_tgt`` is shorter than ``idx_ref``.
    """
    with open(filepath + '.match', 'w') as f:
        f.write('idx_ref,idx_tgt\n')
        for i, ref in enumerate(idx_ref):
            f.write('%d, %d\n' % (ref, idx_tgt[i]))
def draw_cv_matches(src_img, tgt_img, src_kps, tgt_kps, gt_ref, gt_tgt):
    """Render the given keypoint matches side by side with ``cv2.drawMatches``.

    Args:
        src_img: First image handed to ``cv2.drawMatches``.
        tgt_img: Second image handed to ``cv2.drawMatches``.
        src_kps: Keypoints belonging to ``src_img``.
        tgt_kps: Keypoints belonging to ``tgt_img``.
        gt_ref: Sequence of indices into ``src_kps``.
        gt_tgt: Sequence of indices into ``tgt_kps``, paired element-wise
            with ``gt_ref``.

    Returns:
        The image produced by ``cv2.drawMatches``.
    """
    # Indices usually arrive as NumPy integers; cast to plain int so DMatch
    # construction does not depend on how the bindings treat NumPy scalars.
    dmatches = [
        cv2.DMatch(int(ref_idx), int(gt_tgt[i]), 0.)
        for i, ref_idx in enumerate(gt_ref)
    ]
    return cv2.drawMatches(src_img, src_kps, tgt_img, tgt_kps, dmatches, None, flags=0)
def parseArg():
    """Parse the command-line options of this script.

    Options:
        -i / --input: required, input directory.
        --tps_dir: required, directory holding the optimized TPS parameters.
        -d / --dir: flag, set when the input directory contains several
            dataset folders.
        --method: optional, defaults to ``'sift'``. The script body reads
            ``args.method`` to decide how target images are loaded
            (grayscale for ``'sift'``, RGB otherwise); without this option
            that access raised ``AttributeError``.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i", "--input",
        help="Input directory for one or more datasets (use --dir for several)",
        required=True,
    )
    parser.add_argument(
        "--tps_dir",
        help="Input directory containing the optimized TPS params for one or more datasets (use --dir for several)",
        required=True,
    )
    parser.add_argument(
        "-d", "--dir",
        help="is a dir with several dataset folders?",
        action='store_true',
    )
    parser.add_argument(
        "--method",
        help="'sift' loads target images as grayscale; any other value loads them as RGB",
        default='sift',
    )
    return parser.parse_args()
# ---------------------------------------------------------------------------
# Script body.
# For every dataset: detect SIFT keypoints on the master (reference) image and
# on each target image, warp the target keypoints with the stored TPS
# parameters, and match them to the nearest reference keypoint within
# `px_thresh` pixels. Per image it writes `<name>.sift`, and per target also
# `<name>.match` and `<name>_match.png`, all inside the TPS directory.
# ---------------------------------------------------------------------------


def _read_image(path, flags):
    """Read an image with OpenCV, raising instead of returning None on failure."""
    img = cv2.imread(path, flags)
    if img is None:
        raise RuntimeError('Could not read image: ' + path)
    return img


args = parseArg()
args.input = os.path.abspath(args.input)

if args.dir:
    datasets = [d for d in glob.glob(args.input + '/*/*') if os.path.isdir(d)]
else:
    datasets = [args.input]

tps_path = args.tps_dir

# Only paths containing one of these names are processed; everything else is
# dropped without a message.
dataset_keywords = ('DeSurTSampled', 'Kinect1', 'Kinect2Sampled', 'SimulationICCV')
datasets = [d for d in datasets if any(key in d for key in dataset_keywords)]

# The parser may not define --method; fall back to the grayscale SIFT path
# instead of failing with AttributeError.
method = getattr(args, 'method', 'sift')

SIFT = cv2.SIFT_create(nfeatures=2048, contrastThreshold=0.004)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
px_thresh = 1.0  # 3.0 -- maximum distance (in pixels) for a ground-truth match

for dataset in datasets:
    if len(glob.glob(dataset + '/*.csv')) == 0:
        raise RuntimeError('Empty dataset with no .csv file')

    targets = [os.path.splitext(t)[0] for t in glob.glob(dataset + '/*[0-9].csv')]
    master_csvs = glob.glob(dataset + '/*master.csv')
    if not master_csvs:
        raise RuntimeError('No *master.csv file in ' + dataset)
    master = os.path.splitext(master_csvs[0])[0]

    ref_img = _read_image(master + '-rgb.png', 0)

    # The TPS directory mirrors the last two components of the dataset path.
    loading_path = os.path.join(tps_path, *os.path.normpath(dataset).split(os.sep)[-2:])
    if not os.path.exists(loading_path):
        raise RuntimeError('There is no TPS directory in ' + loading_path)

    ref_mask = _read_image(loading_path + '/' + os.path.basename(master) + '_objmask.png', 0)
    ref_kps = SIFT.detect(ref_img, None)
    print('Detected ref kps: ', len(ref_kps))
    # filter by object mask
    ref_kps = [kp for kp in ref_kps if ref_mask[int(kp.pt[1]), int(kp.pt[0])] > 0]
    write_sift(loading_path + '/' + os.path.basename(master), ref_kps)

    # The reference keypoints are identical for every target of this dataset,
    # so the search tree is built once instead of once per image pair.
    ref_tree = KDTree([kp.pt for kp in ref_kps]) if ref_kps else None

    for target in tqdm.tqdm(targets, desc='image pairs'):
        loading_file = loading_path + '/' + os.path.basename(target)
        theta_np = np.load(loading_file + '_theta.npy').astype(np.float32)
        ctrl_pts = np.load(loading_file + '_ctrlpts.npy').astype(np.float32)
        score = _read_image(loading_file + '_SSIM.png', 0) / 255.0
        tgt_mask = _read_image(loading_file + '_objmask.png', 0)

        if method != 'sift':
            target_img = _read_image(target + '-rgb.png', cv2.IMREAD_COLOR)
            target_img = cv2.cvtColor(target_img, cv2.COLOR_BGR2RGB)
        else:
            target_img = _read_image(target + '-rgb.png', 0)

        score_mask = score > 0.25
        target_kps = SIFT.detect(target_img, None)
        print('Detected target kps: ', len(target_kps))
        # filter by object mask
        target_kps = [kp for kp in target_kps if tgt_mask[int(kp.pt[1]), int(kp.pt[0])] > 0]
        # filter by score map with very low confidences
        target_kps = [kp for kp in target_kps if score_mask[int(kp.pt[1]), int(kp.pt[0])]]

        if ref_tree is None or not target_kps:
            # Nothing to match on one side: emit empty match lists rather than
            # crashing while building/querying the tree or broadcasting.
            gt_ref = np.empty(0, dtype=np.int64)
            gt_tgt = np.empty(0, dtype=np.int64)
        else:
            # (width, height) of the target image; the TPS is evaluated on
            # coordinates divided by this and the result is scaled back.
            norm_factor = np.array(target_img.shape[:2][::-1], dtype=np.float32)
            theta = torch.tensor(theta_np, device=device)
            tgt_coords = np.array([kp.pt for kp in target_kps], dtype=np.float32)
            warped_coords = tps.torch.tps_sparse(
                theta,
                torch.tensor(ctrl_pts, device=device),
                torch.tensor(tgt_coords / norm_factor, device=device),
            ).squeeze(0).cpu().numpy() * norm_factor

            dists, idxs_ref = ref_tree.query(warped_coords)
            close = dists < px_thresh  # threshold is in pixels
            gt_tgt = np.arange(len(target_kps))[close]  # Groundtruth indexes
            gt_ref = idxs_ref[close]

            # filter repeated matches: keep the first target keypoint seen for
            # each reference keypoint.
            _, uidxs = np.unique(gt_ref, return_index=True)
            gt_ref = gt_ref[uidxs]
            gt_tgt = gt_tgt[uidxs]

        img_match = draw_cv_matches(ref_img, target_img, ref_kps, target_kps, gt_ref, gt_tgt)
        # Written next to the other outputs (previously a hard-coded,
        # machine-specific directory where cv2.imwrite could fail silently).
        match_png = loading_file + '_match.png'
        if not cv2.imwrite(match_png, img_match):
            print('Warning: could not write ' + match_png)

        write_sift(loading_file, target_kps)
        write_matches(loading_file, gt_ref, gt_tgt)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment