Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tomverbeure/2dbe6457fc512e95032b8ba0d30aee92 to your computer and use it in GitHub Desktop.
Save tomverbeure/2dbe6457fc512e95032b8ba0d30aee92 to your computer and use it in GitHub Desktop.
FFT-based cross correlation
#! /usr/bin/env python3
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
from filter_lib import *
from scipy import signal
def cross_correlation():
N = 256 # Number of samples of the pulse
freq = 10 # Frequency of the sine wave in the pulse
offset = 30 # Delay of TX pulse
delay = 129 # Delay of RX pulse after TX pulse
fft_decim = 1 # Sample drop ratio before doing FFT (1 -> nothing dropped)
np.random.seed(0)
# Create TX and RX pulse
pulse_x = np.linspace(0, 1, N, endpoint=False)
pulse_y = np.sin(2 * np.pi * freq * pulse_x)
pulse_y = pulse_y * np.cos(np.pi * (pulse_x-0.5))**2
tx = np.zeros(offset)
tx = np.append(tx, pulse_y)
tx = np.append(tx, np.zeros(1024-len(tx)) )
tx = tx + 0.05 * np.random.randn(len(tx)) # Make some noise!
rx = np.zeros(offset+delay)
rx = 0.6 * np.append(rx, pulse_y)
rx = np.append(rx, np.zeros(1024-len(rx)) )
rx = rx + 0.20 * np.random.randn(len(rx))
x = np.linspace(0, 1, len(tx), endpoint=False)
# Regular convolution based cross correlation
corr = signal.correlate(tx, rx, mode="same")
# FFT/IFFT based cross correlation
TX = np.fft.fft(tx[::fft_decim])
RX = np.fft.fft(rx[::fft_decim])
CORR = np.multiply(TX, np.conjugate(RX))
corr_fft = np.real(np.fft.ifft(CORR))
# Different location of max correlation between convolution-based and
# FFT-based cross correlation. Exercise for the reader to figure out why...
max_corr = np.argmax(corr)-len(tx)//2
max_corr_fft = (np.argmax(corr_fft) - len(corr_fft)) * fft_decim
# This should print the same result: -129 == delay between RX and TX
print("max corr:", max_corr)
print("max corr fft:", max_corr_fft)
nr_plots = 2
plot_nr = 0
plt.figure(figsize=(10, 8))
plot_nr += 1
plt.subplot(nr_plots, 1, plot_nr)
plt.gca().set_xlim([0.0, 1])
plt.gca().grid(True)
plt.plot(x, tx)
plt.plot(x, rx)
plt.title("Time Domain: Tx and Rx")
plot_nr += 1
plt.subplot(nr_plots, 1, plot_nr)
plt.gca().grid(True)
plt.gca().set_xlim([0.0, len(corr)])
plt.plot(np.arange(len(corr)),corr, ".-")
plt.plot(np.arange(len(corr_fft))*fft_decim,corr_fft, ".-")
plt.title("signal.correlate & FFT-based cross correlation")
plt.tight_layout()
plt.savefig("cross_correlation.svg")
plt.savefig("cross_correlation.png")
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    cross_correlation()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment