Skip to content

Instantly share code, notes, and snippets.

@MGNute
Created August 13, 2020 17:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save MGNute/5bf4fce2c23ac5c131e94e6b437fcc64 to your computer and use it in GitHub Desktop.
Save MGNute/5bf4fce2c23ac5c131e94e6b437fcc64 to your computer and use it in GitHub Desktop.
Fast NumPy Implementation of a Basic Thin-Plate Spline Model
import datetime, os, re
import numpy as np
import scipy as sp
VERBOSE = False
DATETIME_FMT_PRINTABLE = '%Y-%m-%d %H:%M:%S'


def Kernel_TPS(x, y):
    '''Thin-plate spline kernel for a single pair of points: r^2 * log(r) with r = ||y - x||.

    Returns 0. when the points coincide (r == 0, the limit of r^2 log r); otherwise the kernel value is
    passed through np.nan_to_num with nan=0.
    '''
    dist = np.linalg.norm(y - x)
    if dist == 0.:
        return 0.
    return np.nan_to_num(dist ** 2 * np.log(dist), nan=0.)


# Time-stamped (datetime, message) check-points, seeded when the module is imported and appended to by
# checkpoint(); useful for speed-testing and debugging.
checkpoint_list = [(datetime.datetime.now(), 'Module Loaded')]
def tps_kernel_matrix(xr, xc):
    '''
    Pairwise thin-plate spline kernel matrix, k(r) = r^2 * log(r), computed with vectorized numpy operations.

    Only the first two columns of each input are used (the x and y coordinates).

    :param xr: (nr x 2) coordinate array; its points vary down the ROWS of the result.
    :param xc: (nc x 2) coordinate array; its points vary across the COLUMNS of the result.
    :return: (nr x nc) array K with K[i, j] = k(||xr[i] - xc[j]||), and K[i, j] = 0 wherever the distance
        is 0 (the limit of r^2 log r as r -> 0).
    '''
    # Pairwise coordinate differences via broadcasting: (nr, 1) - (1, nc) -> (nr, nc).
    dx = xr[:, 0][:, np.newaxis] - xc[:, 0][np.newaxis, :]
    dy = xr[:, 1][:, np.newaxis] - xc[:, 1][np.newaxis, :]
    R = np.sqrt(dx * dx + dy * dy)
    # At coincident points log(0) = -inf and 0 * -inf = nan. Silence those warnings for this computation only;
    # np.errstate restores the caller's settings even if an exception is raised.
    with np.errstate(divide='ignore', invalid='ignore'):
        K = np.multiply(R ** 2, np.log(R))
    K[R == 0.] = 0.0
    return K
class ThinPlateSpline():
    '''
    Generic thin-plate spline (TPS) regression model over 2-D locations.

    The fitted surface is f(x) = x.a + sum_i w_i * k(||x - X_i||) with k(r) = r^2 log(r), where x is a point in
    homogeneous coordinates [1.0, x, y]. The parameters (w, a) come from solving the smoothed TPS linear system
    (see solve_parameters). The class can fit the model, evaluate it at arbitrary points or on a regular
    lat/lon grid, and save/load its fitted state as a pickle file. It has no extras such as a semi-parametric
    or multi-layer variant.
    '''
    # Training data: X is (n x 3) homogeneous coordinates, Y is (n x 1); K is the (n x n) kernel matrix.
    X = None
    Y = None
    n = None
    K = None
    # Smoothing parameter: n * lambdasmooth * I is added to K when solving.
    lambdasmooth = 0.1
    # Whether fitting routines print progress/timing messages. Defined at class level so the attribute always
    # exists, e.g. on models loaded from a pickle file.
    verbose = True
    # Fitted parameters (w_hat: n x 1 kernel weights, a_hat: 3 x 1 affine coefficients), fitted values,
    # residuals and the standard deviation (np.std, ddof=0) of the residuals.
    w_hat = None
    a_hat = None
    Y_hat = None
    e_hat = None
    std_err_resid = None
    # Optionally stored predictions (see predict(..., store_XYKnew=True)).
    Xnew = None
    Ynew_hat = None
    Knew = None
    # Outputs of grid_predict: prediction points, their predicted values, and the (lat x lon) image.
    prediction_grid_points = None
    prediction_grid_values = None
    prediction_grid_image = None
    # Default map extent as (north, south, west, east).
    map_bounds_NSWE = (50., 20., -130., -60.)
    lats_list = None
    lons_list = None
    # Path last used to save/load this model as a pickle file.
    self_pickle_file_path = None

    def __init__(self, Y=None, X=None, verbose=True, lambda_smoothing=None, from_file_path=None):
        '''
        Initialize by giving Y and X as numpy arrays. Y must be (n x 1) and X must be (n x 3) with the first
        coordinate set to 1.0 (i.e. homogeneous coordinates).

        :param Y: numpy array (N x 1), dependent variable
        :param X: numpy array (N x 3), locations in homogeneous coordinates (i.e. [1.0, x, y])
        :param verbose: if truthy, fitting routines print progress/timing messages.
        :param lambda_smoothing: optional override of the smoothing parameter (class default 0.1).
        :param from_file_path: if this names an existing file, the model state is loaded from it
            (see load_from_pickle_file) and Y, X and lambda_smoothing are ignored.
        '''
        # Set verbosity first so it is defined on this instance even when loading from a file below.
        self.verbose = bool(verbose)
        # If a file path to load is specified, default to that...
        if from_file_path is not None and os.path.isfile(from_file_path):
            self.load_from_pickle_file(from_file_path)
            return
        # ...otherwise use the inputs (if any are given).
        if X is not None and Y is not None:
            Yn = Y.shape[0]
            Xn = X.shape[0]
            assert Yn == Xn, 'Y and X must be the same height. (Y is shape %s, X is shape %s)' % (str(Y.shape), str(X.shape))
            assert Y.shape[1] == 1, 'Y must have width 1. (Y is shape %s)' % str(Y.shape)
            assert X.shape[1] == 3, 'X must have width 3. (X is shape %s)' % str(X.shape)
            assert np.all(X[:,0]==1.0), 'X must have 1.0 in the first column of every row.'
            self.n = Xn
            self.Y = Y
            self.X = X
        if lambda_smoothing is not None:
            self.lambdasmooth = lambda_smoothing

    def save_as_pickle_file(self, file_path=None):
        '''Save the most important attributes of this model, as a python dictionary, to a pickle file.

        :param file_path: destination path; if omitted, the path stored in self.self_pickle_file_path is used.
        :raises FileNotFoundError: if no path is given and none is stored on the object.
        '''
        if file_path is not None:
            self.self_pickle_file_path = file_path
        if self.self_pickle_file_path is None:
            raise FileNotFoundError('No pickle file path was given or stored on this ThinPlateSpline.')
        tps_data = {
            'X': self.X,
            'Y': self.Y,
            'n': self.n,
            'K': self.K,
            'lambdasmooth': self.lambdasmooth,
            'w_hat': self.w_hat,
            'a_hat': self.a_hat,
            'Y_hat': self.Y_hat,
            'e_hat': self.e_hat,
            'std_err_resid': self.std_err_resid,
            'Xnew': self.Xnew,
            'Ynew_hat': self.Ynew_hat,
            'Knew': self.Knew,
            'self_pickle_file_path': self.self_pickle_file_path
        }
        io_save_to_pickle(tps_data, self.self_pickle_file_path)

    def load_from_pickle_file(self, file_path=None):
        '''Load the state of a previously saved model from a pickle file written by save_as_pickle_file.

        NOTE: unpickling can execute arbitrary code -- only load files you trust.

        :param file_path: source path; if omitted, the path stored in self.self_pickle_file_path is used.
        :raises FileNotFoundError: if no path is given and none is stored on the object.
        '''
        if file_path is not None:
            self.self_pickle_file_path = file_path
        if self.self_pickle_file_path is None:
            raise FileNotFoundError('No pickle file path was given or stored on this ThinPlateSpline.')
        tps_data = io_load_from_pickle(self.self_pickle_file_path)
        self.X = tps_data['X']
        self.Y = tps_data['Y']
        self.n = tps_data['n']
        self.K = tps_data['K']
        self.lambdasmooth = tps_data['lambdasmooth']
        self.w_hat = tps_data['w_hat']
        self.a_hat = tps_data['a_hat']
        self.Y_hat = tps_data['Y_hat']
        self.e_hat = tps_data['e_hat']
        self.std_err_resid = tps_data['std_err_resid']
        self.Xnew = tps_data['Xnew']
        self.Ynew_hat = tps_data['Ynew_hat']
        self.Knew = tps_data['Knew']

    def copy_from_existing_TPS(self, existing_TPS):
        '''
        Turn self into a copy of an existing ThinPlateSpline: training data, smoothing parameter, kernel matrix
        and fitted results. Arrays are shared by reference, not duplicated.
        '''
        self.n = existing_TPS.n
        self.Y = existing_TPS.Y
        self.X = existing_TPS.X
        self.verbose = existing_TPS.verbose
        # w_hat/a_hat were solved with this smoothing value, so it belongs to the copied fit.
        self.lambdasmooth = existing_TPS.lambdasmooth
        # 'Kernel_Xnew' is not defined by this class; copy it only when the source object happens to carry one.
        if hasattr(existing_TPS, 'Kernel_Xnew'):
            self.Kernel_Xnew = existing_TPS.Kernel_Xnew
        self.K = existing_TPS.K
        self.w_hat = existing_TPS.w_hat
        self.a_hat = existing_TPS.a_hat
        self.Y_hat = existing_TPS.Y_hat
        self.e_hat = existing_TPS.e_hat
        self.std_err_resid = existing_TPS.std_err_resid

    def compute_kernel_matrix(self):
        '''
        Computes the (N x N) kernel matrix for the training data points X_1, .... , X_N
        '''
        st = datetime.datetime.now()
        if self.verbose:
            print('Computing K matrix beginning at %s' % st.strftime(DATETIME_FMT_PRINTABLE), end='', flush=True)
        self.K = tps_kernel_matrix(self.X[:,1:], self.X[:,1:])
        if self.verbose:
            et = datetime.datetime.now()
            print(', done... (took %s)' % (et - st))

    def solve_parameters(self):
        '''
        Solve the TP-spline for the training data: find the W (kernel weight) and A (affine) parameter vectors,
        then compute fitted Y values and residuals. Computes the kernel matrix first if it is missing.

        The system solved is
            [ K + n*lambda*I   X ] [w]   [Y]
            [ X^T              0 ] [a] = [0]
        '''
        if self.K is None:
            if self.verbose:
                print('Kernel matrix has not yet been computed. Doing that first.')
            self.compute_kernel_matrix()
        st = datetime.datetime.now()
        if self.verbose:
            print('Solving thin-plate spline params, beginning at %s' % st.strftime(DATETIME_FMT_PRINTABLE),
                  end='', flush=True)
        L = np.vstack((np.hstack((self.K + self.n * self.lambdasmooth * np.identity(self.n), self.X)),
                       np.hstack((self.X.T, np.zeros((3, 3), dtype=np.float64)))))
        rhs = np.vstack((self.Y, np.zeros((3, 1), dtype=np.float64)))
        # Solve the system directly instead of forming inv(L): cheaper and numerically more stable.
        wa_hat = np.linalg.solve(L, rhs)
        self.w_hat = wa_hat[:self.n]
        self.a_hat = wa_hat[self.n:]
        self.fit_Y_hat()
        if self.verbose:
            et = datetime.datetime.now()
            print(', done... (took %s)' % (et - st))

    def fit_Y_hat(self):
        '''
        For all the training points, compute the fitted value 'Y_hat', the residual 'e_hat' = Y - Y_hat, and the
        residual standard deviation 'std_err_resid'.
        '''
        self.Y_hat = self.X.dot(self.a_hat) + self.K.dot(self.w_hat)
        self.e_hat = self.Y - self.Y_hat
        self.std_err_resid = np.std(self.e_hat).item()

    def report_diagnostics(self):
        '''
        Print a short summary of the fit (sample size, mean residual, residual standard deviation). Work in
        progress -- intended to grow into something like R's 'summary'.
        '''
        print('Sample Size: %s' % self.n)
        print('Mean Residual: %s' % np.mean(self.e_hat).item())
        print('Residual S.D.: %s' % np.std(self.e_hat).item())

    def predict(self, Xnew, store_XYKnew = False):
        '''
        Compute the predicted value of the spline at new point(s) Xnew. Xnew must be in the same format as
        self.X, i.e. for K points it is (K x 3) with each row in homogeneous coordinates [1.0, x, y].

        :param Xnew: (K x 3) numpy array of points in the X-domain at which to evaluate the spline.
        :param store_XYKnew: if True, also store Xnew, the predictions and the (K x N) kernel matrix on the model
            (as self.Xnew, self.Ynew_hat, self.Knew), replacing any previously stored predictions.
        :return: (K x 1) numpy array of predicted values (returned whether or not they are stored).
        '''
        assert Xnew.shape[1] == 3, 'variable Xnew must have width 3'
        Kmat_Xnew = tps_kernel_matrix(Xnew[:,1:], self.X[:,1:])
        Ynew_hat = Xnew.dot(self.a_hat) + Kmat_Xnew.dot(self.w_hat)
        if store_XYKnew:
            self.clear_predictions()
            self.Xnew = Xnew
            self.Ynew_hat = Ynew_hat
            self.Knew = Kmat_Xnew
        return Ynew_hat

    def clear_predictions(self):
        '''Reset any stored predictions (self.Xnew, self.Ynew_hat, self.Knew) to None.'''
        self.Xnew = None
        self.Ynew_hat = None
        self.Knew = None

    @staticmethod
    def _first_index_map(values):
        '''Map each value to the index of its FIRST occurrence in ``values`` (same answer as list.index).'''
        index_of = {}
        for i, v in enumerate(values):
            index_of.setdefault(v, i)
        return index_of

    def grid_predict(self, lon_resolution = 1.0, lat_resolution = 1.0, map_bounds_NSWE=None):
        '''
        Predict the spline over a regular lat/lon grid covering the map bounds, restricted in each horizontal
        band to the longitude range of the training data in that band. Points are spaced about every
        'lon_resolution' horizontally and 'lat_resolution' vertically.

        Sets self.lons_list / self.lats_list (grid axes), self.prediction_grid_points ((P x 3) homogeneous
        points), self.prediction_grid_values ((P x 1) predictions) and self.prediction_grid_image
        ((n_lats x n_lons) array, NaN where no prediction was made).

        :param map_bounds_NSWE: optional new (north, south, west, east) bounds, stored on the model.
        '''
        if map_bounds_NSWE is not None:
            self.map_bounds_NSWE = map_bounds_NSWE
        self.lons_list, self.lats_list, _lon_res_adj, _lat_res_adj = get_lat_lon_lists(
            lon_resolution, lat_resolution, self.map_bounds_NSWE)
        n_lons = len(self.lons_list)
        n_lats = len(self.lats_list)
        point_rows = []
        for mylat in self.lats_list:
            mylons = get_lons_in_Xrange_for_horizontal_band(self.X, mylat, lat_resolution, self.lons_list,
                                                            lon_resolution)
            point_rows.extend((1., lon, mylat) for lon in mylons)
        # reshape keeps a (0, 3) shape when no grid point falls inside the data range, so predict() still works.
        self.prediction_grid_points = np.array(point_rows, dtype=np.float64).reshape(-1, 3)
        self.prediction_grid_values = self.predict(self.prediction_grid_points)
        self.prediction_grid_image = np.full((n_lats, n_lons), np.nan, dtype=np.float64)
        # Fill in the prediction image. Grid coordinates are looked up in dicts (O(1)) instead of list.index
        # (O(n) per point); they match exactly because every point was built from these same lists.
        lat_to_row = self._first_index_map(self.lats_list)
        lon_to_col = self._first_index_map(self.lons_list)
        values = self.prediction_grid_values.reshape(-1)
        for (lon, lat), value in zip(self.prediction_grid_points[:, 1:].tolist(), values.tolist()):
            self.prediction_grid_image[lat_to_row[lat], lon_to_col[lon]] = value
#
#DEBUGGING/MISC UTILITY FUNCTIONS:
#
def np_dtype_size(dt):
    '''DEBUGGING function. Return the number of bytes per data item for a numpy dtype.

    Accepts anything np.dtype() accepts: a dtype object, a scalar type such as np.float64, or a string such as
    'float32'. E.g. np.float64 returns 8.
    '''
    # itemsize is exact for every dtype (including bool, bytes/unicode strings and datetimes), unlike parsing
    # trailing digits out of str(dt).
    return np.dtype(dt).itemsize
def get_nparray_size_data(myarr, return_str=False):
    '''DEBUGGING function. Summarize how big an array is: (# bytes, shape, size, data-type).

    :param myarr: a numpy array or a scipy.sparse matrix/array.
    :param return_str: if True, return the summary formatted as the string '(bytes, shape, size, dtype)'.
    :return: tuple (# bytes, shape, size, dtype). For sparse input, # bytes sums the stored-value array and any
        index/pointer arrays the format keeps, and size is the sparse object's .size (its number of stored values).
    '''
    # Import the submodule explicitly: 'import scipy' alone does not guarantee scipy.sparse is loaded.
    from scipy import sparse
    if sparse.issparse(myarr):
        # Count every array the format stores alongside the values: CSR/CSC/BSR keep 'indices' and 'indptr',
        # COO keeps 'row' and 'col', DIA keeps 'offsets'.
        tot_bytes = 0
        for attr in ('data', 'indices', 'indptr', 'row', 'col', 'offsets'):
            part = getattr(myarr, attr, None)
            if isinstance(part, np.ndarray):
                tot_bytes += part.nbytes
    else:
        tot_bytes = myarr.nbytes
    shp = myarr.shape
    sz = myarr.size
    dt = myarr.dtype
    if return_str:
        return '(%d, %s, %d, %s)' % (tot_bytes, shp, sz, dt)
    return tot_bytes, shp, sz, dt
def checkpoint(msg, verbose=VERBOSE):
    '''DEBUGGING function. Record a timestamped check-point in the module-level ``checkpoint_list`` and, if
    ``verbose``, print the time elapsed since the previous check-point and since the first one.

    :param msg: label stored with (and printed for) this check-point.
    :param verbose: print a progress line; defaults to the value of VERBOSE when this function was defined.
    '''
    now = datetime.datetime.now()
    since_last = now - checkpoint_list[-1][0]
    since_first = now - checkpoint_list[0][0]
    if verbose:
        # Drop microseconds so the printed durations read as H:MM:SS.
        since_last_p = datetime.timedelta(since_last.days, since_last.seconds, 0)
        since_first_p = datetime.timedelta(since_first.days, since_first.seconds, 0)
        # One placeholder per argument; a stray '%s (GB)' field without a matching value raised TypeError.
        print('%s - %s (%s cume) -- %s' % (now.strftime(DATETIME_FMT_PRINTABLE), since_last_p, since_first_p, msg))
    checkpoint_list.append((now, msg))
def io_save_to_pickle(dat, file_path):
    '''Pickle ``dat`` into ``file_path`` (resolved to an absolute path), overwriting any existing file.'''
    import pickle
    with open(os.path.abspath(file_path), 'wb') as out_file:
        pickle.dump(dat, out_file)
def io_load_from_pickle(file_path):
    '''Load and return the object pickled in ``file_path`` (resolved to an absolute path).

    NOTE(review): pickle.load can execute arbitrary code -- only call this on files you trust.
    '''
    import pickle
    with open(os.path.abspath(file_path), 'rb') as in_file:
        return pickle.load(in_file)
def download_tps_sample_data():
    '''For testing and illustration: download a small sample data set (hosted on Dropbox) and return it as numpy
    arrays.

    The remote text file is parsed as three Python literals, one per line (presumably lists of numbers): the
    dependent variable on line 1 and two coordinate columns on lines 2 and 3.

    :return: (Y, X) where Y is a float64 array built from line 1 and X is a float64 array whose columns are
        lines 2 and 3 (i.e. (n x 2) when the lists have equal length n). These are not yet in the shapes
        ThinPlateSpline expects (Y as n x 1, X as n x 3 with a leading column of 1.0).
    '''
    import ast
    import urllib.request
    lnk = 'https://www.dropbox.com/s/raw/m35fbcorigw7l95/test_spline_data.txt'
    # Close the HTTP response once its lines have been read.
    with urllib.request.urlopen(lnk) as response:
        dat = response.readlines()

    def parse_line(raw_line):
        # The downloaded text is untrusted: literal_eval only accepts Python literals, unlike eval().
        return ast.literal_eval(raw_line.decode('utf-8').strip())

    y_list = parse_line(dat[0])
    X = np.array([parse_line(dat[1]), parse_line(dat[2])], dtype=np.float64).T
    Y = np.array(y_list, dtype=np.float64)
    return Y, X
def get_lat_lon_lists(lon_res, lat_res, NSWE_bounds):
    '''Return evenly spaced longitude and latitude lists spanning the map bounds, plus the actual resolutions.

    Each requested resolution is adjusted slightly downward so that an integer number of steps fits the map
    exactly. If the east bound is not greater than the west bound, the bounds are taken to cross the
    international date line and the east-west span is the complement of |E - W|. Longitudes are wrapped into
    [-180, 180).

    :param lon_res: requested longitude spacing.
    :param lat_res: requested latitude spacing.
    :param NSWE_bounds: (north, south, west, east); north and south may be given in either order.
    :return: (lon_list, lat_list, lon_res_adj, lat_res_adj). For a zero-height (or zero-width) map the
        corresponding list has a single entry and the requested resolution is returned unchanged.
    '''
    import math
    N = max(NSWE_bounds[0], NSWE_bounds[1])
    S = min(NSWE_bounds[0], NSWE_bounds[1])
    W = NSWE_bounds[2]
    E = NSWE_bounds[3]
    NSdist = N - S
    # If the bounds cross the dateline then take the complement.
    EWdist = (E - W) if E > W else 360. - abs(E - W)

    def wrap_lon(lon):
        return (lon + 180.) % 360 - 180.

    # Fewest steps whose size does not exceed the requested resolution.
    n_lat_levels = math.ceil(NSdist / lat_res)
    n_lon_levels = math.ceil(EWdist / lon_res)
    # A degenerate span has zero steps; avoid dividing by zero.
    lat_res_adj = NSdist / n_lat_levels if n_lat_levels else lat_res
    lon_res_adj = EWdist / n_lon_levels if n_lon_levels else lon_res
    lat_list = [S + i * lat_res_adj for i in range(n_lat_levels + 1)]
    lon_list = [wrap_lon(W + i * lon_res_adj) for i in range(n_lon_levels + 1)]
    return lon_list, lat_list, lon_res_adj, lat_res_adj
def get_lons_in_Xrange_for_horizontal_band(X, center_lat, lat_halfwidth, lon_list, lon_res):
    '''For one horizontal band of the map, return (sorted) the grid longitudes that lie within the longitude
    range of the data points in that band.

    :param X: (n x 3) array of points in homogeneous coordinates; column 1 is longitude and column 2 latitude.
    :param center_lat: latitude at the middle of the band.
    :param lat_halfwidth: points with center_lat - lat_halfwidth <= lat <= center_lat + lat_halfwidth are in
        the band.
    :param lon_list: candidate grid longitudes.
    :param lon_res: margin: longitudes strictly within lon_res of the band's [min, max] data longitude are kept.
    :return: sorted list of the selected entries of lon_list; empty if no data point falls in the band.
    '''
    # TODO: handle data whose longitudes straddle the international date line.
    lats = X[:, 2]
    in_band = (lats >= center_lat - lat_halfwidth) & (lats <= center_lat + lat_halfwidth)
    band_lons = X[in_band, 1]
    if band_lons.size == 0:
        return []
    lo = np.min(band_lons) - lon_res
    hi = np.max(band_lons) + lon_res
    return sorted(lon for lon in lon_list if lo < lon < hi)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment