Last active
June 23, 2019 10:17
-
-
Save etienne87/257e1acd5875f0813fc0665acf75087a to your computer and use it in GitHub Desktop.
image stitching experiment
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import print_function | |
import numpy as np | |
import imutils | |
import cv2 | |
def ransac(data, model, n, k, t, d):
    """Robustly fit ``model`` to ``data`` with RANSAC.

    Home-cooked replacement for ``cv2.findHomography(..., cv2.RANSAC)``.
    It works with any ``model`` object exposing ``fit(data) -> fitted`` and
    ``get_error(data, fitted) -> per-row errors``.

    Each of the ``k`` iterations does the following:

    1. Draw ``n`` distinct rows and fit a hypothesis.
    2. Collect every row whose error is below ``t``.
    3. If more than ``d`` rows qualify, refit on them.
    4. Keep the refit model with the lowest mean error over its inliers.

    :param data: array of samples, one per row.
    :param model: object with ``fit`` / ``get_error`` methods.
    :param n: rows drawn (without replacement) per hypothesis; must not
        exceed ``len(data)``.
    :param k: number of hypotheses to try.
    :param t: error threshold under which a row counts as an inlier.
    :param d: a hypothesis is refit only with strictly more than ``d`` inliers.
    :return: ``(bestfit, status, errors)``. ``bestfit`` is the selected
        model. ``status`` is its boolean inlier mask over ``data`` and
        ``errors`` its per-row errors over ``data``. All three are ``None``
        when no hypothesis passed the ``d`` test.
    """
    bestfit = None
    besterr = np.inf
    best_status = None
    best_errors = None
    idx = np.arange(len(data))
    for _ in range(k):
        # Sample without replacement: duplicated rows make the minimal
        # sample degenerate (e.g. < 4 distinct points for a homography).
        sample_idx = np.random.choice(idx, size=n, replace=False)
        maybemodel = model.fit(data[sample_idx])
        inlier_idxs = idx[model.get_error(data, maybemodel) < t]
        if len(inlier_idxs) <= d:
            continue
        inliers = data[inlier_idxs]
        bettermodel = model.fit(inliers)
        thiserr = np.mean(model.get_error(inliers, bettermodel))
        if thiserr < besterr:
            bestfit = bettermodel
            besterr = thiserr
            # Report the mask/errors of the model actually returned, not of
            # the sampled hypothesis or of whatever iteration ran last.
            best_errors = model.get_error(data, bettermodel)
            best_status = best_errors < t
    return bestfit, best_status, best_errors
def normalize(point_list, point_axis=1):
    """Condition homogeneous 2D points before building a DLT system.

    The first two rows of ``point_list`` are shifted to zero mean and scaled
    to unit standard deviation (statistics taken along ``point_axis``). The
    third row is passed through unchanged, so the shift is only meaningful if
    it is 1 (assumed homogeneous ``w == 1`` — TODO confirm for other callers).

    :param point_list: 3xN homogeneous points.
    :param point_axis: axis along which mean / std are computed.
    :return: ``(normalized_points, T)`` with ``normalized_points = T @ point_list``.
    """
    xy = point_list[:2]
    center = np.mean(xy, axis=point_axis)
    # Small epsilon keeps a zero spread (all points aligned) from dividing by 0.
    spread = np.std(xy, axis=point_axis) + 1e-9
    inv_sx = 1.0 / spread[0]
    inv_sy = 1.0 / spread[1]
    transform = np.diag([inv_sx, inv_sy, 1])
    transform[0, 2] = -center[0] / spread[0]
    transform[1, 2] = -center[1] / spread[1]
    return transform.dot(point_list), transform
def calculate_homography_matrix(origin, dest):
    """Estimate the 3x3 homography mapping ``origin`` onto ``dest``.

    Normalized Direct Linear Transform (DLT). See
    http://www.cse.psu.edu/~rtc12/CSE486/lecture16.pdf

    1. Normalize both point sets with ``normalize`` for better conditioning.
    2. Writing ``x' = H_row1.x / H_row3.x`` and ``y' = H_row2.x / H_row3.x``
       gives two linear equations in the 9 entries of H per correspondence.
    3. Solve the homogeneous system ``A h = 0`` in the least-squares sense.
       ``h`` is the right singular vector of ``A`` with the smallest singular
       value.
    4. Undo both normalizations and scale so that ``H[2, 2] == 1``.

    H has 8 degrees of freedom, so at least 4 correspondences (no 3 of them
    collinear) are required for a unique solution. With fewer, the returned
    matrix is arbitrary.

    :param origin: 3xN homogeneous source points (third row expected to be 1).
    :param dest: 3xN homogeneous destination points (third row expected to be 1).
    :return: 3x3 homography with ``H[2, 2] == 1``.
    """
    origin, c1 = normalize(origin)
    dest, c2 = normalize(dest)
    x, y = origin[0], origin[1]
    u, v = dest[0], dest[1]
    zeros = np.zeros_like(x)
    ones = np.ones_like(x)
    # Rows 2i and 2i+1 of A hold the two equations of correspondence i.
    a = np.empty((2 * len(x), 9))
    a[0::2] = np.stack([x, y, ones, zeros, zeros, zeros, -u * x, -u * y, -u], axis=1)
    a[1::2] = np.stack([zeros, zeros, zeros, x, y, ones, -v * x, -v * y, -v], axis=1)
    # Last row of V^T = singular vector for the smallest singular value.
    _, _, vt = np.linalg.svd(a, full_matrices=True)
    homography_matrix = vt[-1].reshape(3, 3)
    # Map back from normalized to original coordinates: H = C2^-1 Hn C1.
    homography_matrix = np.dot(np.linalg.inv(c2), np.dot(homography_matrix, c1))
    return homography_matrix / homography_matrix[2, 2]
def get_homography_error(origin, dest, model, status=None):
    """Mean reprojection error of homography ``model`` from ``origin`` to ``dest``.

    :param origin: Nx2 source points.
    :param dest: Nx2 destination points (only the first two columns are used).
    :param model: 3x3 homography.
    :param status: optional inlier mask with N entries, either flat or Nx1 (as
        returned by ``cv2.findHomography``). When given, only points whose
        mask is non-zero are averaged.
    :return: mean Euclidean distance (numpy float) between ``dest`` and the
        projection of ``origin``. NaN when no point is selected.
    :raises ValueError: if ``model`` is None.
    """
    if model is None:
        raise ValueError("model must be a 3x3 homography, got None")
    origin = np.asarray(origin)
    dest = np.asarray(dest)
    ones = np.ones((len(origin), 1), dtype=np.float32)
    projected = np.concatenate((origin, ones), axis=1).dot(np.asarray(model).T)
    projected = projected[:, :2] / projected[:, 2:3]
    err_per_point = np.sqrt(np.sum((dest[:, :2] - projected) ** 2, axis=1))
    if status is not None:
        # Flatten first: an Nx1 mask multiplied with an (N,) vector would
        # broadcast to NxN instead of selecting points.
        mask = np.asarray(status).reshape(-1).astype(bool)
        err_per_point = err_per_point[mask]
    if err_per_point.size == 0:
        return np.float64(np.nan)
    return err_per_point.mean()
class Homography:
    """RANSAC model adapter for a 2D homography.

    Each row of ``data`` holds a homogeneous source point in columns 0-2,
    followed by its matching destination point in the remaining columns.
    ``find_homography`` builds rows as ``[xA, yA, 1, xB, yB, 1]``.
    """

    def split(self, data):
        """Return ``(source_columns, destination_columns)`` of ``data``."""
        source = data[:, :3]
        target = data[:, 3:]
        return source, target

    def fit(self, data):
        """Estimate the homography that maps sources onto destinations."""
        source, target = self.split(data)
        return calculate_homography_matrix(source.T, target.T)

    def get_error(self, data, model):
        """Per-row Euclidean reprojection error of ``model`` on ``data``.

        The destination's third coordinate is ignored (assumed to be 1).
        """
        source, target = self.split(data)
        projected = source.dot(model.T)
        projected_xy = projected[:, :2] / projected[:, 2:3]
        delta = target[:, :2] - projected_xy
        return np.sqrt(np.square(delta).sum(axis=1))
def find_homography(ptsA, ptsB, reprojThresh, max_iters=10):
    """Estimate the homography mapping ``ptsA`` onto ``ptsB`` with home-made RANSAC.

    Counterpart of ``cv2.findHomography(ptsA, ptsB, cv2.RANSAC, reprojThresh)``.

    :param ptsA: Nx2 source points.
    :param ptsB: Nx2 destination points, in the same order as ``ptsA``.
    :param reprojThresh: reprojection error under which a match is an inlier.
    :param max_iters: number of RANSAC hypotheses (default 10, the value that
        used to be hard-coded).
    :return: ``(H, status)``, a 3x3 homography and a boolean inlier mask. It is
        ``(None, None)`` when fewer than 4 matches are given or no hypothesis
        gathered enough inliers.
    """
    ptsA = np.asarray(ptsA)
    ptsB = np.asarray(ptsB)
    n = len(ptsA)
    if n < 4:
        # A homography has 8 degrees of freedom: 4 correspondences minimum.
        return None, None
    # Sample size: n // 8 as before, but never below the 4-point minimum.
    # Smaller samples leave the DLT system underdetermined (empty for n < 8).
    s = max(4, n // 8)
    # Consensus requirement unchanged: more than 2 * (n // 8) inliers.
    d = 2 * (n // 8)
    ones = np.ones((n, 1), dtype=np.float32)
    data = np.concatenate((ptsA, ones, ptsB, ones), axis=1)
    ransac_fit, status, _ = ransac(data, Homography(), s, max_iters, reprojThresh, d)
    return ransac_fit, status
class Stitcher:
    """Two-image panorama stitcher (AKAZE features + homography warp).

    Experiment: the homography is estimated both with
    ``cv2.findHomography(..., cv2.RANSAC)`` and with the home-made
    ``find_homography``. Both mean reprojection errors are printed, and the
    home-made estimate is the one used for stitching.
    """

    def __init__(self):
        # determine if we are using OpenCV v3.X (or newer)
        self.isv3 = imutils.is_cv3(or_better=True)

    def stitch(self, images, ratio=0.75, reprojThresh=4.0, showMatches=False):
        """Warp the second image into the first one's frame and paste them together.

        :param images: ``(imageB, imageA)``. ``imageB`` is pasted unchanged at
            the top-left corner; ``imageA`` is warped with the estimated
            homography.
        :param ratio: Lowe ratio-test threshold for keypoint matching.
        :param reprojThresh: RANSAC inlier threshold (pixels).
        :param showMatches: also return a match visualization.
        :return: the stitched image, or ``(stitched, visualization)`` when
            ``showMatches`` is true. Returns ``None`` if no homography could be
            computed.
        """
        # unpack the images, then detect keypoints and extract
        # local invariant descriptors from them
        (imageB, imageA) = images
        (kpsA, featuresA) = self.detectAndDescribe(imageA)
        (kpsB, featuresB) = self.detectAndDescribe(imageB)
        # match features between the two images
        M = self.matchKeyPoints(kpsA, kpsB, featuresA, featuresB, ratio, reprojThresh)
        # not enough matches (or no homography found): no panorama
        if M is None:
            return None
        (matches, H, status) = M
        (hA, wA) = imageA.shape[:2]
        (hB, wB) = imageB.shape[:2]
        # The canvas must hold imageB entirely, even when it is taller than imageA.
        result = cv2.warpPerspective(imageA, H, (wA + wB, max(hA, hB)))
        result[0:hB, 0:wB] = imageB
        if showMatches:
            vis = self.drawMatches(imageA, imageB, kpsA, kpsB, matches, status)
            return (result, vis)
        return result

    def detectAndDescribe(self, image):
        """Return ``(Nx2 float32 keypoint coordinates, AKAZE descriptors)`` of a BGR image."""
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        feat = cv2.AKAZE_create()
        kps, features = feat.detectAndCompute(gray, None)
        kps = np.float32([kp.pt for kp in kps])
        return kps, features

    def matchKeyPoints(self, kpsA, kpsB, featuresA, featuresB, ratio, reprojThresh):
        """Ratio-test match A->B descriptors and estimate the A->B homography.

        :return: ``(matches, H, status)``, where ``matches`` is a list of
            ``(trainIdx, queryIdx)`` pairs (train indexes B, query indexes A)
            and ``status`` is the inlier mask. Returns ``None`` when there are
            fewer than 4 good matches or no homography was found.
        """
        # No descriptors at all (e.g. featureless image): nothing to match.
        if featuresA is None or featuresB is None:
            return None
        # AKAZE's default descriptor (MLDB) is binary: compare with Hamming,
        # not the default L2 norm.
        bf = cv2.BFMatcher(cv2.NORM_HAMMING)
        rawMatches = bf.knnMatch(featuresA, featuresB, k=2)
        # Lowe's ratio test: keep a match only if clearly better than the runner-up.
        matches = [
            (m[0].trainIdx, m[0].queryIdx)
            for m in rawMatches
            if len(m) == 2 and m[0].distance < m[1].distance * ratio
        ]
        # computing a homography requires at least 4 matches
        if len(matches) < 4:
            return None
        ptsA = np.float32([kpsA[i] for (_, i) in matches])
        ptsB = np.float32([kpsB[i] for (i, _) in matches])
        # OpenCV reference estimate, only used for the error comparison below.
        (H_cv, _) = cv2.findHomography(ptsA, ptsB, cv2.RANSAC, reprojThresh)
        if H_cv is not None:
            print('Homography errors: ', np.mean(get_homography_error(ptsA, ptsB, H_cv)))
        H, status = find_homography(ptsA, ptsB, reprojThresh)
        if H is None:
            return None
        print('Custom Homography errors: ', np.mean(get_homography_error(ptsA, ptsB, H)))
        return (matches, H, status)

    def drawMatches(self, imageA, imageB, kpsA, kpsB, matches, status):
        """Place imageA and imageB side by side and draw one random-color line per inlier match."""
        (hA, wA) = imageA.shape[:2]
        (hB, wB) = imageB.shape[:2]
        vis = np.zeros((max(hA, hB), wA + wB, 3), dtype="uint8")
        vis[0:hA, 0:wA] = imageA
        vis[0:hB, wA:] = imageB
        for ((trainIdx, queryIdx), s) in zip(matches, status):
            # only draw matches flagged as inliers
            if s == 1:
                color = np.random.randint(0, 255, 3).tolist()
                ptA = (int(kpsA[queryIdx][0]), int(kpsA[queryIdx][1]))
                # imageB is drawn to the right of imageA, hence the wA offset
                ptB = (int(kpsB[trainIdx][0]) + wA, int(kpsB[trainIdx][1]))
                cv2.line(vis, ptA, ptB, color, 1)
        return vis
def test():
    """Interactive demo: stitch ``left.jpeg`` with ``right.jpeg`` and display the result.

    Both images are read from the current working directory.
    """
    left = cv2.imread('left.jpeg')
    right = cv2.imread('right.jpeg')
    # cv2.imread reports a missing/unreadable file by returning None.
    if left is None or right is None:
        raise IOError("could not read 'left.jpeg' and/or 'right.jpeg'")
    print(left.shape)
    output = Stitcher().stitch([left, right], showMatches=True)
    if output is None:
        print('Not enough matches to compute a homography.')
        return
    result, vis = output
    cv2.imshow('stitched', result[::2, ::2])
    # naive = np.concatenate((left, right), axis=1)
    # cv2.imshow('naive_stitch', naive[::2,::2])
    # cv2.imshow('matches', vis)
    cv2.waitKey()
    cv2.destroyAllWindows()
if __name__ == '__main__':
    # Script entry point: run the interactive stitching demo.
    test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment