Created
March 2, 2021 17:47
-
-
Save tomverbeure/2dbe6457fc512e95032b8ba0d30aee92 to your computer and use it in GitHub Desktop.
FFT-based cross correlation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! /usr/bin/env python3 | |
import numpy as np | |
import matplotlib | |
from matplotlib import pyplot as plt | |
from filter_lib import * | |
from scipy import signal | |
def _make_pulse_pair(N, freq, offset, delay, length, tx_noise, rx_noise):
    """Build the noisy TX and RX test signals.

    A sine burst of ``N`` samples (``freq`` cycles across the burst, shaped by
    a cos^2 / Hann envelope) is placed at sample ``offset`` in ``tx`` and,
    scaled by 0.6, at sample ``offset + delay`` in ``rx``. Both signals are
    zero-padded to ``length`` samples and get independent Gaussian noise with
    standard deviation ``tx_noise`` and ``rx_noise`` respectively.

    The global NumPy RNG is seeded with 0 and the TX noise is drawn before the
    RX noise, so the signals are reproducible from run to run.

    Returns:
        (tx, rx): two float arrays of ``length`` samples.

    Raises:
        ValueError: if either pulse would start before sample 0 or end past
            sample ``length``.
    """
    rx_start = offset + delay
    if offset < 0 or rx_start < 0:
        raise ValueError(
            f"pulses must start at a non-negative sample "
            f"(tx starts at {offset}, rx starts at {rx_start})"
        )
    if max(offset, rx_start) + N > length:
        raise ValueError(
            f"pulses of {N} samples starting at {offset} (tx) and {rx_start} (rx) "
            f"do not fit in {length} samples"
        )

    np.random.seed(0)

    pulse_x = np.linspace(0, 1, N, endpoint=False)
    # cos^2(pi*(x - 0.5)) is 0 at both ends and 1 in the middle, so the burst
    # starts and ends at zero amplitude.
    pulse_y = np.sin(2 * np.pi * freq * pulse_x) * np.cos(np.pi * (pulse_x - 0.5)) ** 2

    tx = np.zeros(length)
    tx[offset:offset + N] = pulse_y
    tx = tx + tx_noise * np.random.randn(length)  # Make some noise!

    rx = np.zeros(length)
    rx[rx_start:rx_start + N] = 0.6 * pulse_y
    rx = rx + rx_noise * np.random.randn(length)

    return tx, rx


def _circular_index_to_lag(index, size):
    """Map an index of a length-``size`` circular correlation to a signed lag.

    Circular (FFT-based) correlation stores lag 0 at index 0. Negative lags
    wrap around to the end of the array, in the same order as
    ``np.fft.fftfreq``: indices ``0 .. ceil(size/2)-1`` hold lags
    ``0, 1, ...`` and the remaining indices hold lags ``-(size//2) .. -1``.
    """
    return index - size if index >= (size + 1) // 2 else index


def _plot_cross_correlation(tx, rx, corr, corr_fft, fft_decim):
    """Plot TX/RX and both correlations.

    Saves the figure to ``cross_correlation.svg`` and ``cross_correlation.png``
    in the current directory, then closes it.
    """
    x = np.linspace(0, 1, len(tx), endpoint=False)
    fig = plt.figure(figsize=(10, 8))
    try:
        plt.subplot(2, 1, 1)
        plt.gca().set_xlim([0.0, 1])
        plt.gca().grid(True)
        plt.plot(x, tx)
        plt.plot(x, rx)
        plt.title("Time Domain: Tx and Rx")

        plt.subplot(2, 1, 2)
        plt.gca().grid(True)
        plt.gca().set_xlim([0.0, len(corr)])
        plt.plot(np.arange(len(corr)), corr, ".-")
        # Rescale the decimated FFT-correlation index back to original samples.
        plt.plot(np.arange(len(corr_fft)) * fft_decim, corr_fft, ".-")
        plt.title("signal.correlate & FFT-based cross correlation")

        plt.tight_layout()
        plt.savefig("cross_correlation.svg")
        plt.savefig("cross_correlation.png")
    finally:
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)


def cross_correlation(N=256, freq=10, offset=30, delay=129, fft_decim=1, *,
                      length=1024, tx_noise=0.05, rx_noise=0.20, plot=True):
    """Compare convolution-based and FFT-based cross correlation of a pulse.

    The function builds a TX signal and an RX signal (the same pulse delayed by
    ``delay`` samples, attenuated and noisier). It then estimates the lag
    between them in two ways, prints both estimates and optionally plots the
    results.

    Args:
        N: number of samples in the pulse.
        freq: number of sine cycles inside the pulse.
        offset: start sample of the TX pulse.
        delay: RX pulse delay relative to TX, in samples (may be negative).
        fft_decim: keep every ``fft_decim``-th sample before the FFT
            (1 -> nothing dropped). Samples are dropped without low-pass
            filtering.
        length: total number of samples in each signal.
        tx_noise: standard deviation of the Gaussian noise added to TX.
        rx_noise: standard deviation of the Gaussian noise added to RX.
        plot: when True, write ``cross_correlation.svg`` and
            ``cross_correlation.png`` to the current directory.

    Returns:
        (max_corr, max_corr_fft): the signed lag of peak correlation from
        ``signal.correlate`` and from the FFT method, in original samples.
        For RX trailing TX by ``delay`` samples, both are expected to be
        ``-delay``.

    Raises:
        ValueError: if ``fft_decim < 1`` or the pulses do not fit in
            ``length`` samples.
    """
    if fft_decim < 1:
        raise ValueError(f"fft_decim must be >= 1, got {fft_decim}")

    tx, rx = _make_pulse_pair(N, freq, offset, delay, length, tx_noise, rx_noise)

    # Linear cross correlation. With mode="same" and equal-length inputs, the
    # output has len(tx) samples and lag 0 sits at index len(tx) // 2.
    corr = signal.correlate(tx, rx, mode="same")
    max_corr = int(np.argmax(corr)) - len(tx) // 2

    # Circular cross correlation: IFFT(TX * conj(RX)). Lag 0 is at index 0 and
    # negative lags wrap to the end of the array. This is why the raw argmax
    # location differs from the convolution-based one, and why it needs
    # unwrapping to a signed lag.
    TX = np.fft.fft(tx[::fft_decim])
    RX = np.fft.fft(rx[::fft_decim])
    corr_fft = np.real(np.fft.ifft(TX * np.conjugate(RX)))
    max_corr_fft = _circular_index_to_lag(int(np.argmax(corr_fft)), len(corr_fft)) * fft_decim

    print("max corr:", max_corr)
    print("max corr fft:", max_corr_fft)

    if plot:
        _plot_cross_correlation(tx, rx, corr, corr_fft, fft_decim)

    return max_corr, max_corr_fft
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    cross_correlation()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment