Last active
August 30, 2021 14:23
-
-
Save Nullkooland/53f1157407c84df64f669a4aa43fe375 to your computer and use it in GitHub Desktop.
PyTorch implementation of 2d normalized cross-correlation (NCC)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class NormalizedCrossCorrelation2d(nn.Module):
    """Batched 2D zero-normalized cross-correlation (ZNCC), "valid" mode.

    Channel ``c`` of image ``n`` in ``x`` is correlated with channel ``c`` of
    template ``n`` in ``template`` (stride 1, no padding). The template and
    every sliding window of ``x`` are each reduced to zero mean and unit L2
    norm. Each output value is therefore the Pearson correlation between the
    template and the window under it, and lies in [-1, 1]. An exact (or
    positively affine-transformed) copy of the template scores 1.

    Windows or templates that are (numerically) constant have no defined
    correlation; their outputs are 0.

    Args:
        eps: Lower bound on the norms used as divisors. Windows whose centered
            L2 norm does not exceed it are treated as constant.
    """

    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, x: torch.Tensor, template: torch.Tensor) -> torch.Tensor:
        """Correlate each image channel in ``x`` with its matching template.

        Args:
            x: Images of shape ``[N, C, Hi, Wi]``.
            template: Templates of shape ``[N, C, Ht, Wt]``, with
                ``Ht <= Hi`` and ``Wt <= Wi``.

        Returns:
            ZNCC map of shape ``[N, C, Hi - Ht + 1, Wi - Wt + 1]``.

        Raises:
            ValueError: If either input is not 4-D, or if their N/C
                dimensions differ. The original ``view``-based reshape
                silently mispaired channels when only the product N*C matched.
        """
        if x.dim() != 4 or template.dim() != 4:
            raise ValueError(
                f"expected 4-D NCHW inputs, got x{tuple(x.shape)} "
                f"and template{tuple(template.shape)}"
            )
        if template.shape[:2] != x.shape[:2]:
            raise ValueError(
                f"template N, C {tuple(template.shape[:2])} must match "
                f"x N, C {tuple(x.shape[:2])}"
            )

        batches, channels, hi, wi = x.shape
        ht, wt = template.shape[2:]
        groups = batches * channels

        # ZNCC is invariant to a per-image shift and positive scale, so
        # normalizing x globally leaves the result unchanged. It keeps values
        # near unit scale, which limits float cancellation in the variance below.
        x = F.instance_norm(x)

        # Template: zero mean and unit L2 norm per (N, C).
        t = template - template.mean(dim=(2, 3), keepdim=True)
        t_norm = t.pow(2).sum(dim=(2, 3), keepdim=True).sqrt()
        t = t / t_norm.clamp_min(self.eps)

        # Flatten N and C into groups so that one grouped convolution
        # correlates each image channel with its own template channel.
        image = x.reshape(1, groups, hi, wi)  # -> [1, N*C, Hi, Wi]
        kernel = t.reshape(groups, 1, ht, wt)  # -> [N*C, 1, Ht, Wt]

        # Numerator: sum over the window of x * t. Because t has zero mean,
        # this equals sum of (x - window_mean) * t, so no per-window centering
        # of x is needed.
        numerator = F.conv2d(image, kernel, groups=groups)  # -> [1, N*C, Ho, Wo]

        # Denominator: centered L2 norm of each window,
        # sqrt(n * (E_w[x^2] - E_w[x]^2)) with n = Ht * Wt.
        window = (ht, wt)
        win_mean = F.avg_pool2d(image, window, stride=1)
        win_mean_sq = F.avg_pool2d(image * image, window, stride=1)
        energy = ((win_mean_sq - win_mean * win_mean) * (ht * wt)).clamp_min(0)
        denominator = energy.sqrt()

        # Flat windows have no defined correlation: report 0 instead of
        # amplifying rounding noise by a near-zero divisor.
        valid = denominator > self.eps
        xcorr = torch.where(
            valid,
            numerator / denominator.clamp_min(self.eps),
            torch.zeros_like(numerator),
        )
        # Rounding can push |xcorr| marginally past 1; keep it in range.
        xcorr = xcorr.clamp(-1.0, 1.0)

        ho, wo = xcorr.shape[2:]
        return xcorr.reshape(batches, channels, ho, wo)  # -> [N, C, Ho, Wo]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import cv2 | |
import torch | |
import numpy as np | |
import matplotlib.pyplot as plt | |
if __name__ == "__main__":
    # NOTE(review): in the original gist this demo lived in a separate file
    # and did not import NormalizedCrossCorrelation2d; import it from wherever
    # the class is defined if the two parts are split again.

    # Template crop as (x, y, width, height) in pixels. The original script
    # used an undefined ``ROI`` (NameError). These are example values; adjust
    # them for the image in use.
    ROI = (240, 240, 96, 96)

    # Load the test image as a float grayscale tensor in [0, 1]
    img_u8 = cv2.imread("data/lena.png", cv2.IMREAD_GRAYSCALE)
    if img_u8 is None:
        # cv2.imread signals an unreadable/missing file by returning None
        raise FileNotFoundError("could not read data/lena.png")
    img = torch.tensor(img_u8, dtype=torch.float32) / 255

    x0, y0, w, h = ROI
    if x0 + w > img.shape[1] or y0 + h > img.shape[0]:
        raise ValueError(f"ROI {ROI} exceeds image size {tuple(img.shape)}")
    roi = img[y0:y0 + h, x0:x0 + w]

    # NCHW: replicate image and template into a batch of 2 with 4 channels
    a = torch.tile(img, (2, 4, 1, 1))
    b = torch.tile(roi, (2, 4, 1, 1))

    # Calculate NCC
    ncc = NormalizedCrossCorrelation2d()
    xcorr = ncc(a, b)
    print(xcorr.shape)

    # The template is an exact crop, so the strongest response should sit at
    # the ROI's top-left corner; report it as a sanity check.
    response = xcorr[1, 2]
    peak_y, peak_x = divmod(int(response.argmax()), response.shape[1])
    print(f"peak at (x={peak_x}, y={peak_y}), expected (x={x0}, y={y0})")

    # Visualize results (grayscale inputs shown with a gray colormap)
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    axes[0].imshow(img, cmap="gray", vmin=0, vmax=1)
    axes[1].imshow(roi, cmap="gray", vmin=0, vmax=1)
    axes[2].imshow(xcorr[1, 2, ...], cmap="inferno")
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment