Last active
February 20, 2019 08:56
-
-
Save Sunmish/aff3e752bb564bcb85f1e22d5120f92d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# ---------------------------------------------------------------------------- #
#
# Perform radial basis function interpolation on a sparse grid, given a
# reference array/image that defines the output pixel grid.
# Alternatively, perform nearest-neighbour or piecewise-linear interpolation.
# Generalisation of code present in https://github.com/nhurleywalker/fits_warp
#
# ---------------------------------------------------------------------------- #
from __future__ import print_function, division | |
import numpy as np | |
import sys | |
import psutil | |
import os | |
from astropy.io import fits | |
from astropy.wcs import WCS | |
from scipy.interpolate import Rbf, NearestNDInterpolator, LinearNDInterpolator | |
from scipy.spatial import KDTree | |
import logging | |
# Module-level logger for messages emitted by this file. basicConfig sets the
# root handler format to show the level and the emitting module's name.
logger = logging.getLogger(__name__)
logging.basicConfig(format="%(levelname)s (%(module)s): %(message)s")
logger.setLevel(logging.DEBUG)
def rbf_(arr, x, y, z, interpolation="linear", smooth=0, epsilon=100,
         memfrac=0.75, constrain=True, const=None, absmem="all",
         verbose=True):
    """Sparse grid interpolation with a reference array.

    Interpolate the samples ``z`` located at pixel positions ``(x, y)`` onto
    every pixel of ``np.squeeze(arr)``.

    Parameters
    ----------
    arr : np.ndarray
        Reference array; only its shape is used. After ``np.squeeze`` it must
        be 2-D.
    x : array_like
        Sample positions along the first axis of the squeezed array.
    y : array_like
        Sample positions along the second axis of the squeezed array.
    z : array_like
        Sample values at ``(x, y)``.
    interpolation : str, optional
        'nearest' uses scipy.interpolate.NearestNDInterpolator, 'only_linear'
        uses scipy.interpolate.LinearNDInterpolator (NaN outside the convex
        hull of the samples); any other value is passed as ``function`` to
        scipy.interpolate.Rbf. [Default 'linear']
    smooth : float, optional
        Passed to scipy.interpolate.Rbf. [Default 0]
    epsilon : float, optional
        Passed to scipy.interpolate.Rbf. [Default 100]
    memfrac : float, optional
        Fraction of the currently available memory to budget when ``absmem``
        is 'all'. [Default 0.75]
    constrain : bool, optional
        Rbf only: add extra nodes at the corners and on a coarse interior
        grid, each carrying the value of its nearest sample, so that the RBF
        does not run off to large/small values away from the samples.
        [Default True]
    const : optional
        Unused; accepted for backward compatibility.
    absmem : {'all'} or float, optional
        Memory budget in MB, or 'all' to use ``memfrac`` of the available
        memory. [Default 'all']
    verbose : bool, optional
        Write a progress percentage to stdout while evaluating in chunks.
        [Default True]

    Returns
    -------
    np.ndarray
        float32 array with the shape of ``np.squeeze(arr)``.

    Raises
    ------
    MemoryError
        If the memory budget cannot hold a single chunk.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    z = np.asarray(z)
    rbf_ref = np.full_like(np.squeeze(arr), 1., dtype=np.float32)
    n_rows, n_cols = rbf_ref.shape

    if interpolation == "nearest":  # Not available as an Rbf function.
        z_rbf = NearestNDInterpolator(np.array([x, y]).T, z)
    elif interpolation == "only_linear":
        z_rbf = LinearNDInterpolator(np.array([x, y]).T, z)
    else:
        if constrain:
            # Pin the RBF with extra nodes so it does not diverge towards the
            # image edges; each node takes the value of its nearest sample.
            logger.info("determining boundary pixels and values")
            n_border = 10
            tree_ = KDTree(np.array([x, y]).T)
            # max(..., 1): an axis shorter than n_border would otherwise give
            # a zero step and range() would raise ValueError.
            inc_x = max(n_rows // n_border, 1)
            inc_y = max(n_cols // n_border, 1)
            # Corners (the far ones sit one pixel beyond the last index) ...
            nodes = [(0, 0), (n_rows, n_cols), (n_rows, 0), (0, n_cols)]
            # ... plus a coarse interior grid.
            nodes += [(i, j)
                      for i in range(inc_x, n_rows - inc_x, inc_x)
                      for j in range(inc_y, n_cols - inc_y, inc_y)]
            dist, idx = tree_.query(nodes)
            # Drop nodes that coincide with an existing sample: a duplicate
            # node makes the Rbf linear system singular.
            keep = dist > 0
            nodes = np.asarray(nodes)[keep]
            x = np.append(x, nodes[:, 0])
            y = np.append(y, nodes[:, 1])
            z = np.append(z, z[idx[keep]])
        z_rbf = Rbf(x, y, z, function=interpolation, smooth=smooth,
                    epsilon=epsilon)

    # Evaluation points: every pixel in row-major order, x along the first
    # axis and y along the second (the same convention as the samples), so
    # the flat result reshapes straight back onto the grid.
    rows, cols = np.indices(rbf_ref.shape, dtype=np.float32)
    all_x = rows.ravel()
    all_y = cols.ravel()
    npix = all_x.size
    all_z = np.full_like(all_x, 1., dtype=np.float32)

    # Determine appropriate memory usage:
    if absmem == "all":
        mem = int(psutil.virtual_memory().available*memfrac)
    else:
        # int(): a float budget made ``stride`` a float, breaking range().
        mem = int(absmem*1024.*1024.)
    pixmem = 40000  # budgeted bytes per evaluated pixel
    stride = mem // pixmem
    # Round the chunk size down to a multiple of the first-axis length.
    stride = (stride//n_rows)*n_rows
    if stride == 0:
        raise MemoryError("Not enough memory available for {:.0f}% memory "
                          " fraction.\n{} MB available\n"
                          "at least {} MB required.".format(
                              memfrac*100,
                              mem/1024./1024.,
                              n_rows*pixmem/1024./1024.))

    if npix > stride:
        # list(): on Python 3 a range object has no append().
        borders = list(range(0, npix + 1, stride))
        if borders[-1] != npix:
            borders.append(npix)
        slices = [slice(a, b) for a, b in zip(borders[:-1], borders[1:])]
        for n, s1 in enumerate(slices):
            if verbose:
                sys.stdout.write(u"\u001b[1000D"
                                 + "{:.>6.1f}%".format(100.*n/len(slices)))
                sys.stdout.flush()
            all_z[s1] = z_rbf(all_x[s1], all_y[s1])
        if verbose:
            print("")
    else:
        all_z = z_rbf(all_x, all_y)

    rbf_ref[...] = np.reshape(all_z, rbf_ref.shape)
    return rbf_ref
def rbf(image, x, y, z, interpolation="linear", smooth=0, epsilon=2,
        world_coords=False, outname=None, overwrite=True, const=None,
        constrain=True, memfrac=0.75, absmem="all"):
    """Sparse grid interpolation with a reference FITS image.

    The samples are interpolated onto the pixel grid of the primary HDU of
    ``image`` (with length-1 axes removed) and written to ``outname`` with
    that HDU's header.

    Parameters
    ----------
    image : str
        Reference image.
    x : np.ndarray
        Pixel coordinates along the first (numpy) axis, or the first world
        coordinate if ``world_coords`` is True.
    y : np.ndarray
        Pixel coordinates along the second (numpy) axis, or the second world
        coordinate if ``world_coords`` is True.
    z : np.ndarray
        Values corresponding to x,y coordinates.
    interpolation : str, optional
        Interpolation method, passed to scipy.interpolate.Rbf as
        ``function``. Additionally, 'nearest' uses
        scipy.interpolate.NearestNDInterpolator and 'only_linear' uses
        scipy.interpolate.LinearNDInterpolator. [Default 'linear']
    smooth : float, optional
        Smoothness parameter passed to scipy.interpolate.Rbf. [Default 0]
    epsilon : float, optional
        Epsilon parameter passed to scipy.interpolate.Rbf. Only used for
        certain interpolation functions (e.g. Gaussian). [Default 2]
    world_coords : bool, optional
        True for x,y as world coordinates in decimal degrees; they are
        converted with the image's celestial WCS and rounded to the nearest
        pixel. False for x,y as pixel coordinates. [Default False]
    outname : str, optional
        Output image name. [Default interpolation+'.fits']
    overwrite : bool, optional
        Overwrite ``outname`` if it already exists. [Default True]
    const : optional
        Forwarded to ``rbf_``, which does not currently use it.
    constrain : bool, optional
        Forwarded to ``rbf_``: add nearest-value constraint nodes when an Rbf
        function is used. [Default True]
    memfrac : float, optional
        Fraction of available memory to use when ``absmem`` is 'all'.
        [Default 0.75]
    absmem : {'all'} or float, optional
        Memory budget in MB, or 'all'. [Default 'all']
    """
    x = np.asarray(x)
    y = np.asarray(y)
    z = np.asarray(z)
    with fits.open(image) as ref:
        if world_coords:
            sub_wcs = WCS(ref[0].header).celestial
            # all_world2pix returns pixel coordinates in FITS axis order
            # (NAXIS1 first), the reverse of numpy's axis order; swap them so
            # that x indexes the first numpy axis.
            y, x = sub_wcs.all_world2pix(x, y, 0)
            # Round to the nearest pixel (origin 0) instead of truncating,
            # which shifted points by up to one pixel.
            x = np.rint(x).astype("i")
            y = np.rint(y).astype("i")
        ref[0].data = np.squeeze(ref[0].data)  # We don't need the extra dimensions
        rbf_ref = rbf_(ref[0].data, x, y, z, interpolation, smooth, epsilon,
                       constrain=constrain, const=const, memfrac=memfrac,
                       absmem=absmem)
        if outname is None:
            outname = "{}.fits".format(interpolation)
        # ``overwrite`` supersedes astropy's deprecated ``clobber`` keyword.
        fits.writeto(outname, rbf_ref, ref[0].header, overwrite=overwrite)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment