Skip to content

Instantly share code, notes, and snippets.

@Sunmish
Last active February 20, 2019 08:56
Show Gist options
  • Save Sunmish/aff3e752bb564bcb85f1e22d5120f92d to your computer and use it in GitHub Desktop.
Save Sunmish/aff3e752bb564bcb85f1e22d5120f92d to your computer and use it in GitHub Desktop.
# ---------------------------------------------------------------------------- #
#
# Perform radial basis function interpolation on a sparse grid provided a
# reference array/image is available.
# Alternatively, perform nearest neighbour interpolation.
# Generalisation of code present in https://github.com/nhurleywalker/fits_warp
#
# ---------------------------------------------------------------------------- #
from __future__ import print_function, division
import numpy as np
import sys
import psutil
import os
from astropy.io import fits
from astropy.wcs import WCS
from scipy.interpolate import Rbf, NearestNDInterpolator, LinearNDInterpolator
from scipy.spatial import KDTree
import logging
# Log records are rendered as "LEVEL (module): message". The module logger is
# opened up to DEBUG so nothing is filtered at the logger itself.
_LOG_FORMAT = "%(levelname)s (%(module)s): %(message)s"
logging.basicConfig(format=_LOG_FORMAT)
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.DEBUG)
def _constraint_points(shape, x, y, z, n_border=10):
    """Return ``(x, y, z)`` extended with anchor points for the Rbf fit.

    Anchors are the four corners ``(0, 0)``, ``(nx, ny)``, ``(nx, 0)`` and
    ``(0, ny)``, where ``nx, ny = shape[-2], shape[-1]``. They are followed by
    a regular interior grid spaced ``nx // n_border`` by ``ny // n_border``.
    Each anchor takes the value of its nearest sample in ``(x, y)``. The
    intent is to stop the Rbf surface from running off to large or small
    values away from the data.
    """
    len_arr_x, len_arr_y = shape[-2], shape[-1]
    # A zero step (array smaller than n_border along an axis) would make
    # range() raise ValueError, so fall back to a spacing of one pixel.
    inc_x = max(len_arr_x // n_border, 1)
    inc_y = max(len_arr_y // n_border, 1)
    anchors = [(0, 0), (len_arr_x, len_arr_y), (len_arr_x, 0), (0, len_arr_y)]
    anchors.extend((i, j)
                   for i in range(inc_x, len_arr_x - inc_x, inc_x)
                   for j in range(inc_y, len_arr_y - inc_y, inc_y))
    anchors = np.asarray(anchors)
    # One vectorised nearest-neighbour query instead of one query per anchor.
    dist, idx = KDTree(np.column_stack([x, y])).query(anchors)
    # An anchor lying exactly on a sample would duplicate that node. That
    # gives two identical rows in the Rbf system, which is singular when
    # smooth=0. The sample already fixes the value there, so drop the anchor.
    keep = dist > 0
    anchors, idx = anchors[keep], idx[keep]
    return (np.concatenate([x, anchors[:, 0]]),
            np.concatenate([y, anchors[:, 1]]),
            np.concatenate([z, z[idx]]))


def rbf_(arr, x, y, z, interpolation="linear", smooth=0, epsilon=100,
         memfrac=0.75, constrain=True, const=None, absmem="all",
         verbose=True):
    """Sparse grid interpolation with a reference array.

    Parameters
    ----------
    arr : array_like
        Reference array. Only its shape is used. ``np.squeeze(arr)`` must be
        2-D.
    x, y : array_like
        Sample coordinates along the first and second axes of the squeezed
        ``arr``.
    z : array_like
        Sample values at ``(x, y)``.
    interpolation : str, optional
        Choice of interpolator:
        'nearest' uses scipy.interpolate.NearestNDInterpolator;
        'only_linear' uses scipy.interpolate.LinearNDInterpolator;
        any other value is passed as ``function`` to scipy.interpolate.Rbf.
        [Default 'linear']
    smooth, epsilon : optional
        Passed to scipy.interpolate.Rbf.
    memfrac : float, optional
        Fraction of currently available memory to budget when
        ``absmem == 'all'``. [Default 0.75]
    constrain : bool, optional
        Rbf only: add nearest-neighbour anchor points (corners plus an
        interior grid) before fitting. [Default True]
    const : optional
        Unused; kept for backward compatibility.
    absmem : 'all' or float, optional
        Memory budget in MB, or 'all' to use ``memfrac`` of the available
        memory reported by psutil. [Default 'all']
    verbose : bool, optional
        Write a percentage progress indicator to stdout while evaluating in
        chunks. [Default True]

    Returns
    -------
    np.ndarray
        float32 array shaped like ``np.squeeze(arr)``. Element ``[i, j]``
        holds the interpolant evaluated at ``(x=i, y=j)``.

    Raises
    ------
    MemoryError
        If the memory budget cannot hold a single evaluation chunk.
    """
    x, y, z = np.asarray(x), np.asarray(y), np.asarray(z)
    rbf_ref = np.full_like(np.squeeze(arr), 1., dtype=np.float32)

    if interpolation == "nearest":  # Not available as an Rbf function.
        z_rbf = NearestNDInterpolator(np.column_stack([x, y]), z)
    elif interpolation == "only_linear":
        z_rbf = LinearNDInterpolator(np.column_stack([x, y]), z)
    else:
        if constrain:
            logger.info("determining boundary pixels and values")
            x, y, z = _constraint_points(np.shape(arr), x, y, z)
        z_rbf = Rbf(x, y, z, function=interpolation, smooth=smooth,
                    epsilon=epsilon)

    # Memory budget in bytes. Keep it an integer so that stride is an
    # integer too: range() rejects a float step.
    if absmem == "all":
        mem = int(psutil.virtual_memory().available * memfrac)
    else:
        mem = int(absmem * 1024 * 1024)
    pixmem = 40000  # NOTE(review): heuristic bytes per evaluated pixel.
    stride = mem // pixmem
    # Round the chunk size down to a multiple of the first-axis length.
    stride = (stride // rbf_ref.shape[0]) * rbf_ref.shape[0]
    if stride == 0:
        raise MemoryError("Not enough memory available for {:.0f}% memory "
                          " fraction.\n{} MB available\n"
                          "at least {} MB required.".format(
                              memfrac * 100,
                              mem / 1024. / 1024.,
                              rbf_ref.shape[0] * pixmem / 1024. / 1024.))

    # Row and column index of every pixel, flattened in C order. Using
    # (row, col) rather than (col, row) keeps non-square arrays in bounds.
    rows, cols = np.indices(rbf_ref.shape).reshape(2, -1)
    n_pix = rows.size
    if n_pix > stride:
        all_z = np.empty(n_pix, dtype=np.float32)
        starts = range(0, n_pix, stride)
        for count, start in enumerate(starts):
            if verbose:
                sys.stdout.write(u"\u001b[1000D" + "{:.>6.1f}%".format(
                    100. * count / len(starts)))
                sys.stdout.flush()
            chunk = slice(start, start + stride)
            all_z[chunk] = z_rbf(rows[chunk], cols[chunk])
        if verbose:
            print("")
    else:
        all_z = z_rbf(rows, cols)
    # The flattening is C order, so a reshape puts every value back at its
    # (row, col).
    rbf_ref[...] = np.reshape(all_z, rbf_ref.shape)
    return rbf_ref
def rbf(image, x, y, z, interpolation="linear", smooth=0, epsilon=2,
        world_coords=False, outname=None, overwrite=True, const=None,
        constrain=True, memfrac=0.75, absmem="all"):
    """Sparse grid interpolation with a reference FITS image.

    The samples are interpolated onto the pixel grid of the squeezed primary
    HDU of ``image`` with `rbf_`. The result is written to a new FITS file
    that reuses the primary header.

    Parameters
    ----------
    image : str
        Reference image. Only the primary HDU (``ref[0]``) is used.
    x : np.ndarray
        Pixel coordinates along the first axis of the squeezed data array.
        When ``world_coords`` is True, values of the first celestial world
        coordinate (e.g. RA) in decimal degrees.
    y : np.ndarray
        Pixel coordinates along the second axis. When ``world_coords`` is
        True, values of the second celestial world coordinate (e.g. Dec) in
        decimal degrees.
        Pixel coordinates in both ``x`` and ``y`` are cast with
        ``astype("i")``, i.e. truncated toward zero, before interpolation.
    z : np.ndarray
        Values corresponding to x,y coordinates.
    interpolation : str, optional
        Choice of interpolator:
        'nearest' uses scipy.interpolate.NearestNDInterpolator;
        'only_linear' uses scipy.interpolate.LinearNDInterpolator;
        any other value is passed as ``function`` to scipy.interpolate.Rbf.
        [Default 'linear']
    smooth : int, optional
        Smoothness parameter passed to scipy.interpolate.Rbf. [Default 0]
    epsilon : float, optional
        Epsilon parameter passed to scipy.interpolate.Rbf. Only used for
        certain interpolation functions (e.g. Gaussian). [Default 2]
    world_coords : bool, optional
        True for x,y as world coordinates in decimal degrees. False for x,y as
        pixel coordinates. [Default False]
    outname : str, optional
        Output image name. [Default interpolation+'.fits']
    overwrite : bool, optional
        Overwrite ``outname`` if it already exists. [Default True]
    const : optional
        Forwarded to `rbf_`, which does not use it.
    constrain : bool, optional
        Add nearest-neighbour anchor points before an Rbf fit; see `rbf_`.
        [Default True]
    memfrac : float, optional
        Fraction of available memory to use when ``absmem`` is 'all'.
        [Default 0.75]
    absmem : 'all' or float, optional
        Memory budget in MB for the evaluation, or 'all'. [Default 'all']
    """
    x = np.asarray(x)
    y = np.asarray(y)
    z = np.asarray(z)
    with fits.open(image) as ref:
        if world_coords:
            sub_wcs = WCS(ref[0].header).celestial
            # all_world2pix returns (column, row) pixel positions; swap them
            # so that x indexes the first (row) axis of the data array.
            y, x = sub_wcs.all_world2pix(x, y, 0)
        x = x.astype("i")
        y = y.astype("i")
        ref[0].data = np.squeeze(ref[0].data)  # Drop the extra dimensions.
        rbf_ref = rbf_(ref[0].data, x, y, z, interpolation, smooth, epsilon,
                       constrain=constrain, const=const, memfrac=memfrac,
                       absmem=absmem)
        if outname is None:
            outname = "{}.fits".format(interpolation)
        # ``overwrite`` is the supported keyword; ``clobber`` is deprecated.
        fits.writeto(outname, rbf_ref, ref[0].header, overwrite=overwrite)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment