Created
March 15, 2019 06:01
-
-
Save Daiver/be0d03b219b2ff85f1f9d03d51e682ff to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import cv2 | |
import torch | |
import torch.nn.functional as F | |
import scipy | |
import numpy as np | |
def draw_circle(canvas, point, radius):
    """Paint a filled disc of intensity 255 onto ``canvas`` in place.

    Args:
        canvas: image array handed straight to ``cv2.circle``.
        point: indexable pair; ``point[0]`` and ``point[1]`` are rounded to
            the nearest integer and used as the circle centre.
            NOTE(review): OpenCV reads the centre as ``(x, y)``, i.e.
            ``(column, row)`` -- confirm callers pass it in that order.
        radius: circle radius, forwarded unchanged.
    """
    center = tuple(int(round(coord)) for coord in (point[0], point[1]))
    # thickness=-1 asks OpenCV for a filled shape instead of an outline.
    cv2.circle(canvas, center, radius, color=255, thickness=-1)
def img_grad_col(img):
    """Filter a 2-D image with a 3x3 Sobel-type kernel acting across columns.

    The kernel weights the left neighbours positively and the right
    neighbours negatively, so each output value is a smoothed
    "left minus right" intensity difference.

    Args:
        img: 2-D tensor ``(rows, cols)`` of a floating-point dtype.

    Returns:
        Tensor of shape ``(rows - 2, cols - 2)``; no padding is applied, so
        the one-pixel border is lost. Dtype and device follow ``img``.
    """
    # Build the kernel with the image's dtype/device so float64 or GPU
    # inputs work instead of failing on a float32-CPU kernel.
    kernel = torch.tensor(
        [[1, 0, -1],
         [2, 0, -2],
         [1, 0, -1]],
        dtype=img.dtype,
        device=img.device,
    ).view(1, 1, 3, 3)

    # conv2d expects (batch, channels, rows, cols).
    response = F.conv2d(img.unsqueeze(0).unsqueeze(0), kernel)

    # Strip only the batch and channel axes; a bare squeeze() would also
    # drop a spatial axis of size 1 (e.g. for an image 3 pixels tall).
    return response.squeeze(0).squeeze(0)
def img_grad_row(img):
    """Filter a 2-D image with a 3x3 Sobel-type kernel acting across rows.

    The kernel weights the upper neighbours positively and the lower
    neighbours negatively, so each output value is a smoothed
    "above minus below" intensity difference.

    Args:
        img: 2-D tensor ``(rows, cols)`` of a floating-point dtype.

    Returns:
        Tensor of shape ``(rows - 2, cols - 2)``; no padding is applied, so
        the one-pixel border is lost. Dtype and device follow ``img``.
    """
    # Build the kernel with the image's dtype/device so float64 or GPU
    # inputs work instead of failing on a float32-CPU kernel.
    kernel = torch.tensor(
        [[1, 2, 1],
         [0, 0, 0],
         [-1, -2, -1]],
        dtype=img.dtype,
        device=img.device,
    ).view(1, 1, 3, 3)

    # conv2d expects (batch, channels, rows, cols).
    response = F.conv2d(img.unsqueeze(0).unsqueeze(0), kernel)

    # Strip only the batch and channel axes; a bare squeeze() would also
    # drop a spatial axis of size 1 (e.g. for an image 3 pixels wide).
    return response.squeeze(0).squeeze(0)
def theta_for_patch_center(img_shape, window_size, patch_center):
    """Build the affine matrix that places a sampling window on an image.

    The matrix is meant for ``F.affine_grid`` / ``F.grid_sample`` used with
    ``align_corners=False`` (the PyTorch default since 1.3). Under that
    convention the window samples the image at exactly one-pixel spacing
    and its middle sample lands on ``patch_center``.

    Args:
        img_shape: ``(rows, cols)`` of the image that will be sampled.
        window_size: ``(rows, cols)`` of the patch to extract.
        patch_center: ``(row, col)`` centre of the patch in pixel
            coordinates. A tensor keeps its autograd history; other
            sequences are converted to a tensor.

    Returns:
        Tensor of shape ``(1, 2, 3)``: ``[[sx, 0, tx], [0, sy, ty]]`` where
        the ``x`` row addresses columns and the ``y`` row addresses rows.
    """
    center = torch.as_tensor(patch_center)
    if not center.is_floating_point():
        center = center.to(torch.get_default_dtype())

    img_rows, img_cols = img_shape[0], img_shape[1]

    # Normalised coordinates span 2 units over the whole image, so a window
    # covering `window / image` of that span has one-pixel sample spacing.
    scale_x = center.new_tensor(window_size[1] / img_cols)
    scale_y = center.new_tensor(window_size[0] / img_rows)

    # With align_corners=False pixel k sits at (2k + 1) / size - 1.
    # The previous 2k / (size - 1) - 1 belongs to align_corners=True and
    # moved the patch by up to half a pixel towards the image borders.
    offset_x = (2 * center[1] + 1) / img_cols - 1
    offset_y = (2 * center[0] + 1) / img_rows - 1

    zero = center.new_tensor(0.0)
    theta = torch.stack([
        scale_x, zero, offset_x,
        zero, scale_y, offset_y,
    ]).view(1, 2, 3)
    return theta
def cut_patch(img, window_size, patch_center, inter_mode="bilinear"):
    """Extract a window from a 2-D image by differentiable resampling.

    Args:
        img: 2-D tensor ``(rows, cols)``.
        window_size: ``(rows, cols)`` of the patch to return.
        patch_center: ``(row, col)`` centre handed to
            ``theta_for_patch_center``; may be fractional and may require
            gradients.
        inter_mode: interpolation mode forwarded to ``F.grid_sample``.

    Returns:
        Tensor of shape ``(window_size[0], window_size[1])``.
    """
    theta = theta_for_patch_center(img.shape, window_size, patch_center)
    # grid_sample needs the grid and the image to share dtype and device.
    theta = theta.to(dtype=img.dtype, device=img.device)

    # align_corners is spelled out: the implicit default changed in
    # PyTorch 1.3 and leaving it unset emits a warning. False matches what
    # current releases do by default.
    grid = F.affine_grid(
        theta, [1, 1, window_size[0], window_size[1]], align_corners=False
    )
    sampled = F.grid_sample(
        img.unsqueeze(0).unsqueeze(0), grid, mode=inter_mode, align_corners=False
    )

    # Remove only the batch and channel axes so that a window dimension of
    # size 1 is kept.
    return sampled.squeeze(0).squeeze(0)
def compute_lk_error(frame0, frame1, window_size, patch_center, p):
    """Sum of squared differences between a patch and its displaced match.

    The reference patch is cut from ``frame0`` around ``patch_center``; the
    candidate patch is cut from ``frame1`` around ``patch_center + p``.

    Returns:
        ``(1, 1)`` tensor holding the squared L2 norm of the difference.
    """
    template = cut_patch(frame0, window_size, patch_center)
    candidate = cut_patch(frame1, window_size, patch_center + p)

    # Column vector of per-pixel residuals.
    residual = template.view(-1, 1) - candidate.view(-1, 1)
    return residual.t() @ residual
def compute_jacobian(frame0_x, frame0_y, window_size, patch_center):
    """Stack two gradient patches into an ``(n_pixels, 2)`` matrix.

    A window around ``patch_center`` is cut from each gradient image and
    flattened; column 0 comes from ``frame0_x`` and column 1 from
    ``frame0_y``.
    NOTE(review): the column order must match the component order of the
    displacement being estimated -- verify against the caller.
    """
    columns = [
        cut_patch(gradient_img, window_size, patch_center).view(-1)
        for gradient_img in (frame0_x, frame0_y)
    ]
    return torch.stack(columns, dim=1)
def perform_lk(frame0, frame1, window_size, patch_center, p0, n_iters=200, tol=1e-3):
    """Estimate the displacement of a patch between two frames.

    Iteratively refines ``p`` so that the window of ``frame1`` centred at
    ``patch_center + p`` matches the window of ``frame0`` centred at
    ``patch_center``. The Jacobian and the 2x2 Hessian approximation are
    built once from the gradients of ``frame0`` and reused every step.
    Progress is printed to stdout.

    Args:
        frame0: 2-D float tensor, the reference frame.
        frame1: 2-D float tensor, the frame to search in.
        window_size: ``(rows, cols)`` of the tracked patch.
        patch_center: ``(row, col)`` centre of the patch in ``frame0``.
        p0: initial ``(row, col)`` displacement; it is cloned, not modified.
        n_iters: maximum number of update steps.
        tol: stop as soon as the sum of squared differences drops below it.

    Returns:
        Two-element tensor with the refined displacement.

    Raises:
        RuntimeError: presumably raised by ``inverse()`` when the Hessian
            is singular (e.g. a patch without texture) -- not handled here.
    """
    p = p0.clone()
    print(p)

    # The gradient filters use no padding and shrink the image by one pixel
    # per side. Replicate-pad first so the gradient maps share frame0's
    # pixel grid; otherwise the Jacobian window is sampled about one pixel
    # away from the template window.
    padded = F.pad(
        frame0.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode="replicate"
    ).squeeze(0).squeeze(0)
    frame0_r = img_grad_row(padded)
    frame0_c = img_grad_col(padded)

    # Column 0 pairs with the row component of p, column 1 with the column
    # component, matching the (row, col) layout of patch_center.
    jacobian = compute_jacobian(frame0_r, frame0_c, window_size, patch_center)
    jacobian_t = jacobian.transpose(0, 1)
    hessian_inv = (jacobian_t @ jacobian).inverse()

    # The template does not depend on p, so sample it once.
    vals0 = cut_patch(frame0, window_size, patch_center).view(-1, 1)

    for iter_ind in range(n_iters):
        vals1 = cut_patch(frame1, window_size, patch_center + p).view(-1, 1)
        residual = vals0 - vals1
        err = residual.transpose(0, 1) @ residual
        print(f"{iter_ind + 1}/{n_iters} Loss = {err}")
        if err < tol:
            break

        dp = hessian_inv @ (jacobian_t @ residual)
        # The gradient kernels respond with "previous minus next" sign,
        # hence the step is subtracted rather than added.
        p -= dp.view(2)

    print(p)
    return p
def main():
    """Run the tracker on two synthetic frames and print the results.

    Each frame is a black 128x128 image with one filled disc of radius 5;
    the disc in the second frame is drawn with its first point coordinate
    moved from 64 to 66. The initial matching error and the displacement
    returned by ``perform_lk`` are printed.
    """
    frame_shape = (128, 128)
    canvas1 = np.zeros(frame_shape, dtype=np.uint8)
    canvas2 = np.zeros(frame_shape, dtype=np.uint8)

    draw_circle(canvas1, (64, 64), 5)
    draw_circle(canvas2, (66, 64), 5)

    # The tracker works on float32 tensors.
    canvas1_torch = torch.from_numpy(canvas1).float()
    canvas2_torch = torch.from_numpy(canvas2).float()

    window_size = (11, 11)
    patch_center = torch.tensor([64.0, 64.0], requires_grad=True)
    p = torch.tensor([-1.0, 0.0])

    initial_error = compute_lk_error(
        canvas1_torch, canvas2_torch, window_size, patch_center, p
    )
    print(initial_error)

    estimate = perform_lk(
        canvas1_torch, canvas2_torch, window_size, patch_center, p
    )
    print(estimate)
# Run the demo only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment