Skip to content

Instantly share code, notes, and snippets.

@Nullkooland
Last active August 30, 2021 14:23
Show Gist options
  • Save Nullkooland/53f1157407c84df64f669a4aa43fe375 to your computer and use it in GitHub Desktop.
Save Nullkooland/53f1157407c84df64f669a4aa43fe375 to your computer and use it in GitHub Desktop.
PyTorch implementation of 2d normalized cross-correlation (NCC)
import torch
import torch.nn as nn
import torch.nn.functional as F
class NormalizedCrossCorrelation2d(nn.Module):
    """Plane-wise 2d cross-correlation of instance-normalized tensors.

    Every (batch, channel) plane of the input is correlated with the
    template plane at the same (batch, channel) position; planes are never
    mixed with each other.

    Both tensors are passed through ``F.instance_norm`` once, as a whole,
    before correlating. The normalization is therefore per plane, not per
    sliding window, and the correlation sum is not divided by the template
    size, so the response is not rescaled to a fixed range.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, template: torch.Tensor) -> torch.Tensor:
        """Correlate each plane of ``x`` with the matching plane of ``template``.

        Args:
            x: Input of shape ``[N, C, Hi, Wi]``.
            template: Template of shape ``[N, C, Ht, Wt]`` with the same
                ``N`` and ``C`` as ``x``.

        Returns:
            Correlation maps of shape ``[N, C, Ho, Wo]``, where ``Ho`` and
            ``Wo`` are whatever ``F.conv2d`` produces with its default
            stride, padding and dilation.

        Raises:
            ValueError: If either tensor is not 4-dimensional, or if the
                batch/channel dimensions of the two tensors differ.
        """
        if x.dim() != 4 or template.dim() != 4:
            raise ValueError(
                "x and template must both be 4-dimensional (NCHW), got "
                f"{x.dim()}D and {template.dim()}D tensors"
            )
        # Flattening N and C below would silently pair up the wrong planes
        # if only the product N*C matched, so require both to agree.
        if x.shape[:2] != template.shape[:2]:
            raise ValueError(
                "x and template must share batch and channel dimensions, got "
                f"{tuple(x.shape[:2])} and {tuple(template.shape[:2])}"
            )

        # Do image-wise normalization
        x = F.instance_norm(x)
        template = F.instance_norm(template)

        # Extract NC dimensions
        batches, channels, hi, wi = x.shape
        ht, wt = template.shape[2:]
        groups = batches * channels

        # Flatten the NC dimensions so that one grouped convolution handles
        # every plane independently (one group per plane).
        signal = x.reshape(1, groups, hi, wi)  # -> [1, N*C, Hi, Wi]
        kernel = template.reshape(groups, 1, ht, wt)  # -> [N*C, 1, Ht, Wt]

        # Do correlation and reshape to NCHW output
        xcorr = F.conv2d(signal, kernel, groups=groups)  # -> [1, N*C, Ho, Wo]
        ho, wo = xcorr.shape[2:]
        return xcorr.reshape(batches, channels, ho, wo)  # -> [N, C, Ho, Wo]
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
if __name__ == "__main__":
    # Demo: correlate a grayscale test image with a patch cropped from it
    # and plot the image, the patch and one of the response maps.

    # Load test image
    image_path = "data/lena.png"
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread reports an unreadable/missing file by returning None
        # rather than raising, which would otherwise surface as an obscure
        # tensor-construction error on the next line.
        raise FileNotFoundError(f"Could not read test image: {image_path}")
    img = torch.tensor(img, dtype=torch.float32) / 255

    # NOTE(review): ROI is not defined in the visible source and must be
    # provided at module level. The indexing below uses it as
    # (col, row, width, height) -- presumably (x, y, w, h); confirm.
    roi = img[ROI[1]:ROI[1]+ROI[3], ROI[0]:ROI[0]+ROI[2]]

    # NCHW: repeat the 2d image and patch into 2 batches x 4 channels
    a = torch.tile(img, (2, 4, 1, 1))
    b = torch.tile(roi, (2, 4, 1, 1))

    # Calculate NCC
    ncc = NormalizedCrossCorrelation2d()
    xcorr = ncc(a, b)
    print(xcorr.shape)

    # Visualize results
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    axes[0].imshow(img, vmin=0, vmax=1)
    axes[1].imshow(roi, vmin=0, vmax=1)
    # Every (batch, channel) plane holds the same data here, so the choice
    # of plane [1, 2] is arbitrary.
    axes[2].imshow(xcorr[1, 2, ...], cmap="inferno")
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment