-
-
Save zed/ac45fb5d117d8ecb66a3 to your computer and use it in GitHub Desktop.
Discrete convolution vs. analytical solution
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
def convolve(y1, y2, dx=None):
    '''
    Compute the finite (causal) convolution of two equally long signals
    using the trapezoidal rule.

    For each output index k, the integral of y1(tau) * y2(t_k - tau) over
    [0, t_k] is approximated on the sample grid by

        z[k] = sum_{i=0}^{k-1} (f_i + f_{i+1}) / 2,   f_i = y1[i] * y2[k - i]

    This equals the plain discrete convolution sum sum_{i=0}^{k} f_i minus
    half of its two end-point terms f_0 and f_k. Consequently z[0] is
    always 0.

    @param y1: First signal (1-D array-like).
    @param y2: Second signal (1-D array-like, same length as y1).
    @param dx: [optional] Integration step width; if given, the result is
        multiplied by it.
    @return: Array with the same length as y1. Integer/bool inputs are
        promoted to float64 so that the halving and the dx scaling are exact.
    @raise ValueError: If y1 and y2 have different lengths.
    @note: Based on the algorithm at http://www.physics.rutgers.edu/~masud/computing/WPark_recipes_in_python.html.
    '''
    y1, y2 = map(np.asarray, [y1, y2])
    P = len(y1)  # Length of the signal
    if len(y2) != P:
        raise ValueError(
            "convolve expects signals of equal length, got %d and %d"
            % (P, len(y2)))

    # Work in an inexact dtype. Integer storage would truncate the /2 below
    # and reject an in-place multiplication by a float dx.
    dtype = np.result_type(y1, y2)
    if not np.issubdtype(dtype, np.inexact):
        dtype = np.float64
    y1 = y1.astype(dtype, copy=False)
    y2 = y2.astype(dtype, copy=False)

    if P == 0:
        # np.convolve rejects empty input; the convolution of nothing is empty.
        return np.zeros(0, dtype=dtype)

    # Full linear convolution has length 2P - 1; the first P samples are the
    # causal sums sum_{i=0}^{k} y1[i] * y2[k - i].
    z = np.convolve(y1, y2)[:P]
    # Trapezoidal rule = full sum minus half of the end-point terms
    # f_0 = y1[0] * y2[k] and f_k = y1[k] * y2[0].
    z = z - 0.5 * (y1[0] * y2 + y1 * y2[0])

    if dx is not None:  # Is a step width specified?
        z *= dx
    return z
steps = 50  # Number of integration steps
maxtime = 5  # Maximum time
dt = float(maxtime) / steps  # Width of one time step
# Sample grid t_k = k * dt for k = 0 .. steps - 1 (endpoint excluded so the
# spacing is exactly dt).
t = np.linspace(0, maxtime, steps, endpoint=False)
exp1 = np.exp(-t)  # f(t) = exp(-t)
exp2 = 2 * np.exp(-2 * t)  # g(t) = 2 exp(-2t)

# Analytical convolution:
#   (f * g)(t) = int_0^t exp(-tau) * 2 exp(-2 (t - tau)) dtau
#              = 2 exp(-2t) * (exp(t) - 1)
analytical = exp2 * (np.exp(t) - 1)

# Trapezoidal-rule convolution
trapezoidal = convolve(exp1, exp2, dt)

# Convolution via numpy.convolve. The full result has length 2*steps - 1;
# its first `steps` samples are the causal sums on the grid t. Multiplying
# by dt gives a rectangle-rule integral. Subtracting half of each sum's two
# end-point terms (exp1[0]*exp2[k] and exp1[k]*exp2[0]) turns it into the
# trapezoidal rule, with no empirical shift or scale factor needed.
sci = np.convolve(exp1, exp2)[:steps] * dt
sci = sci - 0.5 * dt * (exp1[0] * exp2 + exp1 * exp2[0])

# Plot
plt.plot(t, analytical, label='analytical')
plt.plot(t, trapezoidal, 'o', label='trapezoidal')
plt.plot(t, sci, '.', label='numpy.convolve')
plt.legend()
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment