# Gist tacaswell/8949218 by @tacaswell (last active August 29, 2015 13:56).
class uneven_FFT_recon(object):
    """
    Band-limited Fourier reconstruction of irregularly sampled periodic data.

    An instance stores complex Fourier coefficients ``A_n`` of a real,
    2*pi-periodic curve::

        f(theta) = A_n[0].real
                   + 2 * sum_{k>=1} (A_n[k].real * cos(k*theta)
                                     - A_n[k].imag * sin(k*theta))

    and is callable to evaluate the curve (or its derivatives) at arbitrary
    angles.  Use :meth:`reconstruct` or :meth:`reconstruct_iterative` to fit
    the coefficients to unevenly sampled data.

    Both fitting routines work the same way.  The samples are linearly
    interpolated onto an even grid, FFT-ed and truncated to give a first
    guess.  The guess is then refined by repeatedly adding the truncated FFT
    of the residual at the sample points, resampled onto the grid with
    nearest-neighbour interpolation.
    """

    # number of points in the even grid the data are resampled onto for the FFT
    _N_GRID = 1024

    # ------------------------------------------------------------------
    # private helpers shared by the fitting routines
    # ------------------------------------------------------------------
    @staticmethod
    def _sort_periodic(phi, h):
        """Return ``(phi, h)`` as arrays, ``phi`` wrapped into [0, 2*pi) and
        both sorted by ``phi``."""
        phi = np.mod(np.asarray(phi), 2 * np.pi)
        h = np.asarray(h)
        order = np.argsort(phi)
        return phi[order], h[order]

    @staticmethod
    def _wrap(x, pad, shift=0):
        """Extend ``x`` periodically by ``pad`` samples on each side.

        The last ``pad`` entries are prepended (minus ``shift``) and the first
        ``pad`` entries are appended (plus ``shift``).  Pass ``shift=2*pi``
        for the sample angles and leave it at 0 for the sample values.
        """
        return np.hstack((x[-pad:] - shift, x, x[:pad] + shift))

    @classmethod
    def _grid_fft(cls, padded_phi, padded_vals, n_modes, kind='linear'):
        """Resample onto an even grid over [0, 2*pi) and return the FFT
        coefficients for modes 0..n_modes.

        The grid excludes the 2*pi endpoint so that FFT bin k corresponds
        exactly to exp(i*k*theta).  Dividing by the grid size converts
        numpy's unnormalised FFT into Fourier-series coefficients.  The
        result is a fresh array, not a view into the full FFT.
        """
        grid = np.linspace(0, 2 * np.pi, cls._N_GRID, endpoint=False)
        interp = scipy.interpolate.interp1d(padded_phi, padded_vals, kind=kind)
        return np.fft.fft(interp(grid))[:n_modes + 1] / cls._N_GRID

    @staticmethod
    def _basis(phi, n_modes):
        """Return ``(cos_list, sin_list)``; row ``k-1`` of each holds
        cos(k*phi) / sin(k*phi) for k = 1..n_modes."""
        angles = np.arange(1, n_modes + 1).reshape(-1, 1) * phi
        return np.cos(angles), np.sin(angles)

    @staticmethod
    def _synthesize(A_n, cos_list, sin_list):
        """Evaluate the Fourier series with coefficients ``A_n`` on a basis
        precomputed by :meth:`_basis`."""
        A_list = A_n[1:].reshape(-1, 1)
        return (2 * np.sum(A_list.real * cos_list - A_list.imag * sin_list,
                           axis=0)
                + A_n[0].real)

    # ------------------------------------------------------------------
    # public constructors
    # ------------------------------------------------------------------
    @classmethod
    def reconstruct_iterative(cls, phi, h, target_error,
                              start_N=3, bound_N=15, iters=25):
        """
        Reconstruct a band-limited curve from unevenly sampled points,
        growing the bandwidth until a target error is met.

        The fit starts with ``start_N`` modes, and the coefficients are
        refined for at most ``iters`` steps.  Refinement at a bandwidth stops
        early if the error improves by less than ``0.01 * target_error`` in
        a step.  If the target has not been met, the bandwidth is increased
        by one mode and refinement continues, up to and including ``bound_N``
        modes.

        Parameters
        ----------
        phi : array
            The sample points.  Values are wrapped into [0, 2*pi) and need
            not be sorted.
        h : array
            The value of the curve at each sample point.
        target_error : float
            Stop as soon as the mean absolute difference between the
            reconstruction and ``h`` at the sample points falls below this.
        start_N : int, optional
            The initial bandwidth.  Defaults to 3.
        bound_N : int, optional
            The maximum bandwidth.  Defaults to 15.
        iters : int, optional
            The maximum number of refinement steps at each bandwidth.
            Defaults to 25.

        Returns
        -------
        reconstruction : uneven_FFT_recon
            A callable object.  If the target was never met it has
            ``bound_N + 1`` coefficients (``start_N + 1`` if ``start_N``
            already exceeds ``bound_N``).
        """
        pad = 1
        phi, h = cls._sort_periodic(phi, h)
        padded_phi = cls._wrap(phi, pad, 2 * np.pi)

        max_N = start_N
        # first guess: truncated FFT of the linearly interpolated data
        A_n = cls._grid_fft(padded_phi, cls._wrap(h, pad), max_N)
        while True:
            cos_list, sin_list = cls._basis(phi, max_N)
            converged = False
            prev_err = None
            for _ in range(iters):
                error_h = h - cls._synthesize(A_n, cos_list, sin_list)
                # mean *absolute* error, so residuals of opposite sign
                # cannot cancel and fake convergence
                mean_err = np.mean(np.abs(error_h))
                if mean_err < target_error:
                    converged = True
                    break
                stalled = (prev_err is not None and
                           prev_err - mean_err < .01 * target_error)
                prev_err = mean_err
                A_n = A_n + cls._grid_fft(padded_phi, cls._wrap(error_h, pad),
                                          max_N, kind='nearest')
                if stalled:
                    # the correction above is kept, but more steps at this
                    # bandwidth are not paying off
                    break
            if converged or max_N >= bound_N:
                break
            # widen the band by one mode; its coefficient starts at zero and
            # is refined on the next pass
            max_N += 1
            A_n = np.append(A_n, 0)
        return cls(A_n)

    @classmethod
    def reconstruct(cls, phi, h, max_N=10, iters=25):
        """
        Reconstruct a band-limited curve from unevenly sampled points using
        a fixed number of modes.

        Parameters
        ----------
        phi : array
            The sample points.  Values are wrapped into [0, 2*pi) and need
            not be sorted.
        h : array
            The value of the curve at each sample point.
        max_N : int, optional
            The number of modes to use.  Defaults to 10.
        iters : int, optional
            The maximum number of refinement iterations.  Iteration stops
            early once the mean absolute error at the sample points improves
            by less than 0.1% between steps.  Defaults to 25.

        Returns
        -------
        reconstruction : uneven_FFT_recon
            A callable object with ``max_N + 1`` coefficients.
        """
        pad = 2
        phi, h = cls._sort_periodic(phi, h)
        padded_phi = cls._wrap(phi, pad, 2 * np.pi)
        cos_list, sin_list = cls._basis(phi, max_N)

        # first guess: truncated FFT of the linearly interpolated data
        A_n = cls._grid_fft(padded_phi, cls._wrap(h, pad), max_N)
        prev_err = None
        for _ in range(iters):
            error_h = h - cls._synthesize(A_n, cos_list, sin_list)
            curr_err = np.mean(np.abs(error_h))
            # bail once the error improves by less than 0.1%; the explicit
            # zero check avoids a 0/0 when the previous fit was exact
            if prev_err is not None and (
                    prev_err == 0 or prev_err - curr_err < .001 * prev_err):
                break
            prev_err = curr_err
            A_n = A_n + cls._grid_fft(padded_phi, cls._wrap(error_h, pad),
                                      max_N, kind='nearest')
        return cls(A_n)

    # ------------------------------------------------------------------
    # instance interface
    # ------------------------------------------------------------------
    def __init__(self, A_n):
        self._A_n = np.asarray(A_n)
        self.N = len(A_n)

    @property
    def A_n(self):
        """The complex Fourier coefficients; ``A_n[0]`` is the mean."""
        return self._A_n

    @property
    def max_N(self):
        """The highest mode number represented (``len(A_n) - 1``)."""
        return len(self.A_n) - 1

    def __call__(self, th, deriv=0):
        """
        Evaluate the reconstructed curve, or a derivative of it, at ``th``.

        Parameters
        ----------
        th : float or array
            Angles at which to evaluate.  A scalar gives a length-1 array.
        deriv : int, optional
            Order of the derivative with respect to ``th``.  Defaults to 0.
            For a negative value each mode is scaled by ``k**deriv``; the
            ``A_n[0]`` term is only included when ``deriv == 0``.

        Returns
        -------
        values : ndarray

        Raises
        ------
        ValueError
            If ``deriv`` is not an integer.
        """
        th = np.atleast_1d(th)
        # float k so that negative ``deriv`` (k**-1, ...) is allowed; with
        # only the mean term (N == 1) this is empty and the sum below is 0
        k = np.arange(1, self.N, dtype=float).reshape(-1, 1)
        angles = k * th
        scale = k ** deriv
        cos_t = scale * np.cos(angles)
        sin_t = scale * np.sin(angles)

        # each derivative maps (cos, sin) -> (-sin, cos): a quarter turn
        quarter_turns = deriv % 4
        if quarter_turns == 0:
            a, b = cos_t, sin_t
        elif quarter_turns == 1:
            a, b = -sin_t, cos_t
        elif quarter_turns == 2:
            a, b = -cos_t, -sin_t
        elif quarter_turns == 3:
            a, b = sin_t, -cos_t
        else:
            raise ValueError("deriv must be an integer")

        A_list = self._A_n[1:].reshape(-1, 1)
        A0 = self._A_n[0].real if deriv == 0 else 0
        return 2 * np.sum(A_list.real * a - A_list.imag * b, axis=0) + A0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment