Created October 10, 2022 05:57
-
-
Save jmmshn/56aa94c4bd0be1d9c6c48128106d3819 to your computer and use it in GitHub Desktop.
Fourier Interpolation python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools

import numpy as np
import numpy.typing as npt
def pad_arr(arr_in: npt.NDArray, shape: list[int]) -> npt.NDArray:
    """Zero-pad (or truncate) an FFT-ordered array on a hypercube.

    ``arr_in`` is treated as a spectrum in the standard layout produced by
    ``np.fft.fftn``. Along every axis of length ``n``:

    - the leading ``ceil(n / 2)`` entries hold the zero and positive
      frequencies;
    - the trailing ``floor(n / 2)`` entries hold the negative frequencies.

    Each of the ``2**ndim`` corners of that hypercube is copied into the
    matching corner of a zero-filled array of the requested shape. Growing
    an axis inserts zeros between the two halves. Shrinking an axis discards
    the highest frequencies. For an even-length axis the Nyquist entry
    belongs to the trailing half, so when growing it is copied to the
    negative-frequency side only.

    Parameters
    ----------
    arr_in:
        Data to be padded with zeros (or truncated).
    shape:
        Desired shape of the output array; must have one entry per
        dimension of ``arr_in``.

    Returns
    -------
    NDArray:
        Array of shape ``shape`` with the same dtype as ``arr_in``.

    Raises
    ------
    ValueError:
        If ``shape`` does not have exactly ``arr_in.ndim`` entries.
    """
    shape = tuple(shape)
    if len(shape) != arr_in.ndim:
        raise ValueError(
            f"shape {shape} has {len(shape)} entries but the input array "
            f"has {arr_in.ndim} dimensions"
        )

    # Per axis: one slice for the leading (zero/positive frequency) block and
    # one for the trailing (negative frequency) block. Only as many entries
    # as fit in both arrays are kept.
    axis_slices = []
    for n_in, n_out in zip(arr_in.shape, shape):
        n_keep = min(n_in, n_out)
        n_pos = (n_keep + 1) // 2  # ceil(n_keep / 2)
        n_neg = n_keep // 2  # floor(n_keep / 2)
        # slice(-0, None) would select the entire axis, so an empty
        # negative block must be an explicitly empty slice.
        neg = slice(-n_neg, None) if n_neg > 0 else slice(0, 0)
        axis_slices.append((slice(0, n_pos), neg))

    arr_out = np.zeros(shape, dtype=arr_in.dtype)
    # Every combination of (leading, trailing) per axis is one hypercube
    # corner. Negative-start slices address the end of each array, so the
    # same index tuple lines up in both arrays despite their different sizes.
    for corner in itertools.product(*axis_slices):
        arr_out[corner] = arr_in[corner]
    return arr_out
def interpolate_fourier(arr_in: npt.NDArray, shape: list[int]) -> npt.NDArray:
    """Resample an array onto a new grid by Fourier interpolation.

    The data is taken to frequency space with ``np.fft.fftn``. Its spectrum
    is then resized to ``shape`` with ``pad_arr`` and transformed back with
    ``np.fft.ifftn``.

    Parameters
    ----------
    arr_in:
        Input array of data.
    shape:
        Desired shape of the interpolated data.

    Returns
    -------
    NDArray:
        Interpolated data in the desired shape. Only the real part is
        returned when ``arr_in`` is not complex; otherwise the complex
        result is returned.
    """
    spectrum = pad_arr(np.fft.fftn(arr_in), shape)
    # With NumPy's default normalisation, ifftn divides by the number of
    # output samples while fftn applied no factor. Rescaling by
    # size(out) / size(in) keeps the interpolated values at the input's
    # amplitude.
    resampled = np.fft.ifftn(spectrum) * np.size(spectrum) / np.size(arr_in)
    if np.iscomplexobj(arr_in):
        return resampled
    # Real input: keep only the real part. The resized spectrum may not be
    # exactly Hermitian (e.g. a Nyquist entry present on one side only), so
    # the inverse transform can carry an imaginary component.
    return np.real(resampled)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment