Skip to content

Instantly share code, notes, and snippets.

@etienne87
Last active June 23, 2019 10:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save etienne87/257e1acd5875f0813fc0665acf75087a to your computer and use it in GitHub Desktop.
Save etienne87/257e1acd5875f0813fc0665acf75087a to your computer and use it in GitHub Desktop.
image stitching experiment
from __future__ import print_function
import numpy as np
import imutils
import cv2
"""
Homecooked RANSAC, we will see if we manage to replace cv2.findHomography(..., cv2.RANSAC)
"""
def ransac(data, model, n, k, t, d):
bestfit = None
besterr = np.inf
best_inlier_idxs = None
status = None
idx = np.arange(len(data))
for i in range(k):
jdx = np.random.choice(idx, size=n, replace=True)
maybeinliers = data[jdx]
maybemodel = model.fit(maybeinliers)
all_errors = model.get_error(data, maybemodel)
inliers_idxs = idx[all_errors < t]
#print('all errors: ', all_errors.min(), all_errors.max(), ' len inliers: ', len(inliers_idxs), ' thresh: ', t)
if len(inliers_idxs) > d:
betterdata = data[inliers_idxs]
bettermodel = model.fit(betterdata)
better_errs = model.get_error(betterdata, bettermodel)
thiserr = np.mean(better_errs)
if thiserr < besterr:
#print('better error: ', model.get_error(data, bettermodel).mean())
bestfit = bettermodel
besterr = thiserr
best_inlier_idxs = inliers_idxs
status = all_errors < t
return bestfit, status, all_errors
def normalize(point_list, point_axis=1):
    """Shift and scale the x/y rows of homogeneous points to zero mean, unit std.

    :param point_list: 3 x N array; rows 0 and 1 are x and y, row 2 the
        homogeneous coordinate (the translation assumes it equals 1).
    :param point_axis: axis of rows 0-1 along which mean and std are taken.
    :return: ``(normalized_points, transform)`` where ``transform`` is the 3x3
        matrix applied, i.e. ``normalized_points = transform @ point_list``.
    """
    coords = point_list[:2]
    # the epsilon keeps the scale finite when every point shares a coordinate
    sigma = np.std(coords, axis=point_axis) + 1e-9
    mu = np.mean(coords, axis=point_axis)
    transform = np.diag([1.0 / sigma[0], 1.0 / sigma[1], 1])
    for row in (0, 1):
        transform[row][2] = -mu[row] / sigma[row]
    return np.dot(transform, point_list), transform
def calculate_homography_matrix(origin, dest):
    """Estimate the homography mapping ``origin`` points onto ``dest`` points.

    Normalized Direct Linear Transform (DLT). A homography has 8 degrees of
    freedom and each correspondence yields two linear equations, so at least
    4 correspondences are needed; more points give a least-squares estimate.
    See http://www.cse.psu.edu/~rtc12/CSE486/lecture16.pdf

    1. normalize both point sets, keeping the normalization matrices c1, c2;
    2. from x' = H_row1.x / H_row3.x and y' = H_row2.x / H_row3.x build two
       rows of A per correspondence so that A h = 0;
    3. take h as the right singular vector of A with the smallest singular
       value, then undo the normalization: H = c2^-1 H_n c1.

    With fewer than 4 correspondences the system is underdetermined and the
    result is one arbitrary solution among many.

    :param origin: 3 x N homogeneous source points (third row expected to be 1).
    :param dest: 3 x N homogeneous destination points (third row expected to be 1).
    :return: 3x3 homography scaled so that ``H[2, 2] == 1``.
    """
    origin, c1 = normalize(origin)
    dest, c2 = normalize(dest)
    x, y = origin[0], origin[1]
    u, v = dest[0], dest[1]
    zeros = np.zeros_like(x)
    ones = np.ones_like(x)
    a = np.empty((2 * len(x), 9))
    # even rows encode the x' equation, odd rows the y' equation
    a[0::2] = np.stack([x, y, ones, zeros, zeros, zeros, -u * x, -u * y, -u], axis=1)
    a[1::2] = np.stack([zeros, zeros, zeros, x, y, ones, -v * x, -v * y, -v], axis=1)
    # Zero rows do not change the null space. Padding to 9 rows lets a thin
    # SVD return all 9 right singular vectors without computing the
    # (2N x 2N) U matrix that full_matrices=True would build.
    if a.shape[0] < 9:
        a = np.vstack([a, np.zeros((9 - a.shape[0], 9))])
    _, _, vh = np.linalg.svd(a, full_matrices=False)
    # singular values are sorted descending: the last row is the null vector
    homography_matrix = vh[-1].reshape(3, 3)
    homography_matrix = np.dot(np.linalg.inv(c2), np.dot(homography_matrix, c1))
    return homography_matrix / homography_matrix[2, 2]
def get_homography_error(origin, dest, model, status=None):
    """Return the mean reprojection error of ``model`` over the matches.

    :param origin: N x 2 source points.
    :param dest: N x 2 (or wider) destination points; only the first two
        columns are used.
    :param model: 3x3 homography mapping ``origin`` onto ``dest``.
    :param status: optional per-match mask/weights of length N, either flat
        ``(N,)`` or a column ``(N, 1)``. Masked-out points contribute zero
        error but are still counted in the mean.
    :return: mean Euclidean distance between ``dest`` and the projection of
        ``origin`` (a NumPy scalar, so ``.mean()`` still works on it).
    :raises ValueError: if ``model`` is None.
    """
    if model is None:
        raise ValueError("model must be a 3x3 homography, got None")
    ones = np.ones((len(origin), 1), dtype=np.float32)
    origin_h = np.concatenate((origin, ones), axis=1)
    dest_fit = origin_h.dot(model.T)
    dest_fit2 = dest_fit[:, :2] / dest_fit[:, 2:3]
    err_per_point = np.sqrt(np.sum((dest[:, :2] - dest_fit2) ** 2, axis=1))
    if status is not None:
        # Flatten so an (N, 1) mask multiplies element-wise instead of
        # broadcasting against the (N,) errors into an N x N matrix.
        err_per_point = err_per_point * np.asarray(status).reshape(-1)
    return err_per_point.mean()
class Homography:
    """Model adapter that lets ``ransac`` estimate a homography.

    Each row of ``data`` is one correspondence: columns 0-2 hold the
    homogeneous source point and columns 3 onward the destination point.
    """

    def split(self, data):
        """Return ``(source_rows, destination_rows)`` of ``data``."""
        source = data[:, :3]
        target = data[:, 3:]
        return source, target

    def fit(self, data):
        """Estimate a 3x3 homography from the correspondences in ``data``."""
        source, target = self.split(data)
        return calculate_homography_matrix(source.T, target.T)

    def get_error(self, data, model):
        """Return the per-row Euclidean reprojection error of ``model``."""
        source, target = self.split(data)
        projected = np.dot(source, model.T)
        projected_xy = projected[:, :2] / projected[:, 2:3]
        delta = target[:, :2] - projected_xy
        return np.sqrt(np.sum(delta ** 2, axis=1))
def find_homography(ptsA, ptsB, reprojThresh, max_iters=10, sample_size=None,
                    min_inliers=None):
    """Estimate the homography mapping ``ptsA`` onto ``ptsB`` with RANSAC.

    :param ptsA: N x 2 source points.
    :param ptsB: N x 2 destination points matched row-by-row with ``ptsA``.
    :param reprojThresh: reprojection error (in the points' units) under
        which a match counts as an inlier.
    :param max_iters: number of RANSAC iterations.
    :param sample_size: points drawn per iteration. Defaults to
        ``max(4, N // 8)``: a homography needs at least 4 correspondences,
        and ``N // 8`` alone is smaller than that for N < 32 (and 0 for
        N < 8), which would give arbitrary fits.
    :param min_inliers: a hypothesis needs more than this many inliers to be
        refit; defaults to ``2 * (N // 8)`` (about a quarter of the matches).
    :return: ``(H, status)``: the 3x3 homography and the boolean inlier mask,
        both None when no hypothesis had enough inliers.
    """
    n = len(ptsA)
    if sample_size is None:
        sample_size = max(4, n // 8)
    if min_inliers is None:
        min_inliers = 2 * (n // 8)
    # append the homogeneous coordinate to both point sets: rows are
    # [xA, yA, 1, xB, yB, 1], the layout Homography.split expects
    ones = np.ones((n, 1), dtype=np.float32)
    data = np.concatenate((ptsA, ones, ptsB, ones), axis=1)
    ransac_fit, ransac_status, _ = ransac(
        data, Homography(), sample_size, max_iters, reprojThresh, min_inliers)
    return ransac_fit, ransac_status
class Stitcher:
    """Stitch two overlapping images into a panorama using a homography."""

    def __init__(self):
        # determine if we are using OpenCV v3.X (or newer)
        self.isv3 = imutils.is_cv3(or_better=True)

    def stitch(self, images, ratio=0.75, reprojThresh=4.0, showMatches=False):
        """Warp one image into the frame of the other and paste them together.

        :param images: pair ``(imageB, imageA)``; ``imageA`` is warped into the
            frame of ``imageB``, which is pasted at the top-left corner.
        :param ratio: Lowe's ratio-test threshold for descriptor matching.
        :param reprojThresh: RANSAC reprojection threshold.
        :param showMatches: also return a visualization of the inlier matches.
        :return: the stitched image, ``(stitched, visualization)`` when
            ``showMatches`` is set, or None if no homography could be found.
        """
        (imageB, imageA) = images
        (kpsA, featuresA) = self.detectAndDescribe(imageA)
        (kpsB, featuresB) = self.detectAndDescribe(imageB)
        M = self.matchKeyPoints(kpsA, kpsB, featuresA, featuresB, ratio,
                                reprojThresh)
        # not enough matched keypoints to create a panorama
        if M is None:
            return None
        (matches, H, status) = M
        # The canvas must be tall enough for both images: with only imageA's
        # height, pasting a taller imageB below would fail on a shape mismatch.
        height = max(imageA.shape[0], imageB.shape[0])
        width = imageA.shape[1] + imageB.shape[1]
        result = cv2.warpPerspective(imageA, H, (width, height))
        result[0:imageB.shape[0], 0:imageB.shape[1]] = imageB
        if showMatches:
            vis = self.drawMatches(imageA, imageB, kpsA, kpsB, matches, status)
            return (result, vis)
        return result

    def detectAndDescribe(self, image):
        """Detect AKAZE keypoints on the grayscale image and describe them.

        :return: ``(kps, features)``: a float32 array of keypoint ``(x, y)``
            coordinates and the descriptor matrix returned by OpenCV
            (presumably None when no keypoint is found; handled downstream).
        """
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        feat = cv2.AKAZE_create()
        kps, features = feat.detectAndCompute(gray, None)
        kps = np.float32([kp.pt for kp in kps])
        return kps, features

    def matchKeyPoints(self, kpsA, kpsB, featuresA, featuresB, ratio,
                       reprojThresh):
        """Match descriptors with Lowe's ratio test and fit a homography A->B.

        :return: ``(matches, H, status)`` where ``matches`` holds
            ``(trainIdx, queryIdx)`` pairs (train indexes ``kpsB``, query
            indexes ``kpsA``), ``H`` maps ``kpsA`` points onto ``kpsB`` points
            and ``status`` flags the inlier matches; None when there are too
            few matches or no homography could be estimated.
        """
        if featuresA is None or featuresB is None:
            return None
        # Binary descriptors (uint8, e.g. AKAZE's default MLDB) must be
        # compared with Hamming distance; float descriptors use L2.
        norm = cv2.NORM_HAMMING if featuresA.dtype == np.uint8 else cv2.NORM_L2
        bf = cv2.BFMatcher(norm)
        rawMatches = bf.knnMatch(featuresA, featuresB, k=2)
        # Lowe's ratio test: keep a match only when the best candidate is
        # clearly closer than the second best
        matches = [(m[0].trainIdx, m[0].queryIdx) for m in rawMatches
                   if len(m) == 2 and m[0].distance < m[1].distance * ratio]
        # a homography needs at least 4 matches; require more than 4 so the
        # estimate is over-determined
        if len(matches) <= 4:
            return None
        ptsA = np.float32([kpsA[i] for (_, i) in matches])
        ptsB = np.float32([kpsB[i] for (i, _) in matches])
        # OpenCV's RANSAC, kept as a reference for the homecooked version
        (H_cv, status_cv) = cv2.findHomography(ptsA, ptsB, cv2.RANSAC,
                                               reprojThresh)
        if H_cv is not None:
            errs = get_homography_error(ptsA, ptsB, H_cv)
            print('Homography errors: ', errs.mean())
        H, status = find_homography(ptsA, ptsB, reprojThresh)
        if H is None:
            # homecooked RANSAC found no consensus: fall back to OpenCV's
            if H_cv is None:
                return None
            H, status = H_cv, status_cv.ravel()
        else:
            errs = get_homography_error(ptsA, ptsB, H)
            print('Custom Homography errors: ', errs.mean())
        return (matches, H, status)

    def drawMatches(self, imageA, imageB, kpsA, kpsB, matches, status):
        """Draw the inlier matches between imageA (left) and imageB (right).

        :return: uint8 canvas of shape ``(max(hA, hB), wA + wB, 3)``.
        """
        (hA, wA) = imageA.shape[:2]
        (hB, wB) = imageB.shape[:2]
        vis = np.zeros((max(hA, hB), wA + wB, 3), dtype="uint8")
        vis[0:hA, 0:wA] = imageA
        vis[0:hB, wA:] = imageB
        for ((trainIdx, queryIdx), s) in zip(matches, status):
            # only draw matches flagged as inliers
            if s == 1:
                color = np.random.randint(0, 255, 3).tolist()
                ptA = (int(kpsA[queryIdx][0]), int(kpsA[queryIdx][1]))
                # imageB is drawn to the right of imageA, hence the wA offset
                ptB = (int(kpsB[trainIdx][0]) + wA, int(kpsB[trainIdx][1]))
                cv2.line(vis, ptA, ptB, color, 1)
        return vis
def test():
    """Stitch left.jpeg and right.jpeg from the working directory and show it."""
    left = cv2.imread('left.jpeg')
    right = cv2.imread('right.jpeg')
    # cv2.imread returns None instead of raising when a file cannot be read
    for name, image in (('left.jpeg', left), ('right.jpeg', right)):
        if image is None:
            raise IOError('could not read image: ' + name)
    print(left.shape)
    stitched = Stitcher().stitch([left, right], showMatches=True)
    if stitched is None:
        print('not enough matches to stitch the images')
        return
    result, vis = stitched
    cv2.imshow('stitched', result[::2, ::2])
    # naive = np.concatenate((left, right), axis=1)
    # cv2.imshow('naive_stitch', naive[::2,::2])
    # cv2.imshow('matches', vis)
    cv2.waitKey()


if __name__ == '__main__':
    test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment