Skip to content

Instantly share code, notes, and snippets.

@dinomono dinomono/stitch.py Secret
Created Mar 14, 2018

Embed
What would you like to do?
Example of image stitching algorithm
#!/usr/bin/python
import os
import sys
import cv2
import math
import numpy as np
import utils
from numpy import linalg
class Stitch(object):
def __init__(self, image_dir, key_frame, output_dir, img_filter=None):
'''
image_dir: 'directory' containing all images
key_frame: 'dir/name.jpg' of the base image
output_dir: 'directory' where to save output images
optional:
img_filter = 'JPG'; None->Take all images
'''
self.key_frame_file = os.path.split(key_frame)[-1]
self.output_dir = output_dir
# Open the directory given in the arguments
self.dir_list = []
try:
self.dir_list = os.listdir(image_dir)
if img_filter:
# remove all files that doen't end with .[image_filter]
self.dir_list = filter(lambda x: x.find(img_filter) > -1, self.dir_list)
try: #remove Thumbs.db, is existent (windows only)
self.dir_list.remove('.DS_Store')
except ValueError:
pass
except:
print >> sys.stderr, ("Unable to open directory: %s" % image_dir)
sys.exit(-1)
self.dir_list = map(lambda x: os.path.join(image_dir, x), self.dir_list)
self.dir_list = filter(lambda x: x != key_frame, self.dir_list)
base_img_rgb = cv2.imread(key_frame)
if base_img_rgb == None:
raise IOError("%s doesn't exist" %key_frame)
final_img = self.stitch(base_img_rgb, 0)
def filter_matches(self, matches, ratio = 0.75):
filtered_matches = []
for m in matches:
if len(m) == 2 and m[0].distance < m[1].distance * ratio:
filtered_matches.append(m[0])
return filtered_matches
def imageDistance(self, matches):
sumDistance = 0.0
for match in matches:
sumDistance += match.distance
return sumDistance
def findDimensions(self, image, homography):
base_p1 = np.ones(3, np.float32)
base_p2 = np.ones(3, np.float32)
base_p3 = np.ones(3, np.float32)
base_p4 = np.ones(3, np.float32)
(y, x) = image.shape[:2]
base_p1[:2] = [0,0]
base_p2[:2] = [x,0]
base_p3[:2] = [0,y]
base_p4[:2] = [x,y]
max_x = None
max_y = None
min_x = None
min_y = None
for pt in [base_p1, base_p2, base_p3, base_p4]:
hp = np.matrix(homography, np.float32) * np.matrix(pt, np.float32).T
hp_arr = np.array(hp, np.float32)
normal_pt = np.array([hp_arr[0]/hp_arr[2], hp_arr[1]/hp_arr[2]], np.float32)
if ( max_x == None or normal_pt[0,0] > max_x ):
max_x = normal_pt[0,0]
if ( max_y == None or normal_pt[1,0] > max_y ):
max_y = normal_pt[1,0]
if ( min_x == None or normal_pt[0,0] < min_x ):
min_x = normal_pt[0,0]
if ( min_y == None or normal_pt[1,0] < min_y ):
min_y = normal_pt[1,0]
min_x = min(0, min_x)
min_y = min(0, min_y)
return (min_x, min_y, max_x, max_y)
def stitch(self, base_img_rgb, round=0):
if ( len(self.dir_list) < 1 ):
return base_img_rgb
base_img = cv2.GaussianBlur(cv2.cvtColor(base_img_rgb,cv2.COLOR_BGR2GRAY), (5,5), 0)
# Use the SIFT feature detector
detector = cv2.xfeatures2d.SIFT_create()
# Find key points in base image for motion estimation
base_features, base_descs = detector.detectAndCompute(base_img, None)
# Parameters for nearest-neighbor matching
FLANN_INDEX_KDTREE = 1
flann_params = dict(algorithm = FLANN_INDEX_KDTREE,
trees = 5)
matcher = cv2.FlannBasedMatcher(flann_params, {})
print "Iterating through next images..."
closestImage = None
next_img_path = self.dir_list[0]
print next_img_path
print "Reading %s..." % next_img_path
# Read in the next image...
next_img_rgb = cv2.imread(next_img_path)
next_img = cv2.GaussianBlur(cv2.cvtColor(next_img_rgb,cv2.COLOR_BGR2GRAY), (5,5), 0)
print "\t Finding points..."
# Find points in the next frame
next_features, next_descs = detector.detectAndCompute(next_img, None)
matches = matcher.knnMatch(next_descs, trainDescriptors=base_descs, k=2)
print "\t Match Count: ", len(matches)
matches_subset = self.filter_matches(matches)
print "\t Filtered Match Count: ", len(matches_subset)
distance = self.imageDistance(matches_subset)
print "\t Distance from Key Image: ", distance
averagePointDistance = distance/float(len(matches_subset))
print "\t Average Distance: ", averagePointDistance
kp1 = []
kp2 = []
for match in matches_subset:
kp1.append(base_features[match.trainIdx])
kp2.append(next_features[match.queryIdx])
p1 = np.array([k.pt for k in kp1])
p2 = np.array([k.pt for k in kp2])
H, status = cv2.findHomography(p1, p2, cv2.RANSAC, 5.0)
print '%d / %d inliers/matched' % (np.sum(status), len(status))
inlierRatio = float(np.sum(status)) / float(len(status))
if ( closestImage == None or inlierRatio > closestImage['inliers'] ):
closestImage = {}
closestImage['h'] = H
closestImage['inliers'] = inlierRatio
closestImage['dist'] = averagePointDistance
closestImage['path'] = next_img_path
closestImage['rgb'] = next_img_rgb
closestImage['img'] = next_img
closestImage['feat'] = next_features
closestImage['desc'] = next_descs
closestImage['match'] = matches_subset
print "Closest Image: ", closestImage['path']
print "Closest Image Ratio: ", closestImage['inliers']
self.dir_list = filter(lambda x: x != closestImage['path'], self.dir_list)
H = closestImage['h']
H = H / H[2,2]
H_inv = linalg.inv(H)
if ( closestImage['inliers'] > 0.1 ): # and
(min_x, min_y, max_x, max_y) = self.findDimensions(closestImage['img'], H_inv)
# Adjust max_x and max_y by base img size
max_x = max(max_x, base_img.shape[1])
max_y = max(max_y, base_img.shape[0])
move_h = np.matrix(np.identity(3), np.float32)
if ( min_x < 0 ):
move_h[0,2] += -min_x
max_x += -min_x
if ( min_y < 0 ):
move_h[1,2] += -min_y
max_y += -min_y
print "Homography: \n", H
print "Inverse Homography: \n", H_inv
print "Min Points: ", (min_x, min_y)
mod_inv_h = move_h * H_inv
img_w = int(math.ceil(max_x))
img_h = int(math.ceil(max_y))
print "New Dimensions: ", (img_w, img_h)
# crop edges
print "Cropping..."
base_h, base_w, base_d = base_img_rgb.shape
next_h, next_w, next_d = closestImage['rgb'].shape
base_img_rgb = base_img_rgb[5:(base_h-5),5:(base_w-5)]
closestImage['rgb'] = closestImage['rgb'][5:(next_h-5),5:(next_w-5)]
# Warp the new image given the homography from the old image
base_img_warp = cv2.warpPerspective(base_img_rgb, move_h, (img_w, img_h))
print "Warped base image"
# utils.showImage(base_img_warp, scale=(0.2, 0.2), timeout=1000, save=True, title="base_img_warp")
# cv2.destroyAllWindows()
next_img_warp = cv2.warpPerspective(closestImage['rgb'], mod_inv_h, (img_w, img_h))
print "Warped next image"
# utils.showImage(next_img_warp, scale=(0.2, 0.2), timeout=1000, save=True, title="next_img_warp")
# cv2.destroyAllWindows()
# Put the base image on an enlarged palette
enlarged_base_img = np.zeros((img_h, img_w, 3), np.uint8)
print "Enlarged Image Shape: ", enlarged_base_img.shape
print "Base Image Shape: ", base_img_rgb.shape
print "Base Image Warp Shape: ", base_img_warp.shape
# enlarged_base_img[y:y+base_img_rgb.shape[0],x:x+base_img_rgb.shape[1]] = base_img_rgb
# enlarged_base_img[:base_img_warp.shape[0],:base_img_warp.shape[1]] = base_img_warp
# Create masked composite
(ret,data_map) = cv2.threshold(cv2.cvtColor(next_img_warp, cv2.COLOR_BGR2GRAY),
0, 255, cv2.THRESH_BINARY)
# add base image
enlarged_base_img = cv2.add(enlarged_base_img, base_img_warp,
mask=np.bitwise_not(data_map),
dtype=cv2.CV_8U)
# add next image
final_img = cv2.add(enlarged_base_img, next_img_warp,
dtype=cv2.CV_8U)
# Crop black edge
final_gray = cv2.cvtColor(final_img, cv2.COLOR_BGR2GRAY)
_, thresh = cv2.threshold(final_gray, 1, 255, cv2.THRESH_BINARY)
dino, contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
print "Found %d contours..." % (len(contours))
max_area = 0
best_rect = (0,0,0,0)
for cnt in contours:
x,y,w,h = cv2.boundingRect(cnt)
deltaHeight = h-y
deltaWidth = w-x
area = deltaHeight * deltaWidth
if ( area > max_area and deltaHeight > 0 and deltaWidth > 0):
max_area = area
best_rect = (x,y,w,h)
if ( max_area > 0 ):
final_img_crop = final_img[best_rect[1]:best_rect[1]+best_rect[3],
best_rect[0]:best_rect[0]+best_rect[2]]
final_img = final_img_crop
# output
final_filename = "%s/%d.JPG" % (self.output_dir, round)
cv2.imwrite(final_filename, final_img)
return self.stitch(final_img, round+1)
else:
return self.stitchImages(base_img_rgb, round+1)
# ----------------------------------------------------------------------------
if __name__ == '__main__':
if ( len(sys.argv) < 4 ):
print >> sys.stderr, ("Usage: %s <image_dir> <key_frame> <output>" % sys.argv[0])
sys.exit(-1)
Stitch(sys.argv[1], sys.argv[2], sys.argv[3])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.
You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session.