Created
August 13, 2020 17:03
-
-
Save MGNute/5bf4fce2c23ac5c131e94e6b437fcc64 to your computer and use it in GitHub Desktop.
Fast NumPy Implementation of a Basic Thin-Plate Spline Model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import datetime, os, re | |
import numpy as np | |
import scipy as sp | |
VERBOSE = False

DATETIME_FMT_PRINTABLE = '%Y-%m-%d %H:%M:%S'


def Kernel_TPS(x, y):
    '''Thin-plate spline kernel for a single pair of points: r**2 * log(r), where r = ||y - x||.

    Coincident points (r == 0) return 0.0, the limit of r**2 * log(r) as r -> 0. Any NaN in the
    result is mapped to 0.0 by np.nan_to_num. See tps_kernel_matrix for the vectorized version.
    '''
    r = np.linalg.norm(y - x)
    if r == 0.:
        return 0.
    return np.nan_to_num(r ** 2 * np.log(r), nan=0.)


# Useful to maintain a set of check-points specifically for speed-testing and debugging.
checkpoint_list = [(datetime.datetime.now(), 'Module Loaded')]
def tps_kernel_matrix(xr, xc):
    '''
    Evaluate the thin-plate spline kernel k(r) = r^2 * log(r) pairwise between two coordinate sets.

    Entry [i, j] of the result is k(||(xr[i, 0], xr[i, 1]) - (xc[j, 0], xc[j, 1])||), with k(0) defined
    as 0.0 (the limit of r^2 log r as r -> 0). Everything is done with vectorized numpy operations,
    which is far faster than calling a per-pair kernel in a Python loop.

    :param xr: numpy array (nr x 2) of coordinates that vary down the rows of the kernel matrix.
        Only the first two columns are used.
    :param xc: numpy array (nc x 2) of coordinates that vary across the columns of the kernel matrix.
        Only the first two columns are used.
    :return: numpy array (nr x nc), the kernel matrix.
    '''
    # Pairwise coordinate differences by broadcasting (nr, 1) against (1, nc) -> (nr, nc).
    dx = xr[:, 0][:, np.newaxis] - xc[:, 0][np.newaxis, :]
    dy = xr[:, 1][:, np.newaxis] - xc[:, 1][np.newaxis, :]
    R = np.sqrt(dx * dx + dy * dy)
    # Substitute 1.0 where R == 0 so log() never sees zero. R**2 is 0 there, so K is exactly 0.0.
    # This avoids the divide-by-zero warning without touching numpy's global error state
    # (a manual seterr/restore pair would leak the "ignore" setting if anything raised in between).
    K = R ** 2 * np.log(np.where(R == 0., 1., R))
    return K
class ThinPlateSpline():
    '''
    This class implements a generic Thin-Plate spline model. It has routines for fitting the model to a set of
    training data as well as for evaluating the model on arbitrary points. It does not have any of the bells and
    whistles, for example a semi-parametric version or the multi-layer version used in the bias calc functions
    later on.

    Typical use: construct with Y (n x 1) and X (n x 3, rows [1.0, x, y]), call solve_parameters() (which
    computes the kernel matrix first if needed), then predict() or grid_predict().
    '''
    # Class-level None defaults; instance attributes shadow them once set.
    X = None
    Y = None
    n = None
    K = None  # (n x n) kernel matrix over the training points
    lambdasmooth = 0.1  # smoothing penalty; solve_parameters adds n * lambdasmooth to the diagonal of K
    verbose = True  # fallback so methods that print progress never hit a missing attribute
    w_hat = None  # (n x 1) kernel weights
    a_hat = None  # (3 x 1) affine coefficients for [1.0, x, y]
    Y_hat = None
    e_hat = None
    std_err_resid = None
    # Most recent stored predictions (see predict(store_XYKnew=True)).
    Xnew = None
    Ynew_hat = None
    Knew = None
    # Results of grid_predict().
    prediction_grid_points = None
    prediction_grid_values = None
    prediction_grid_image = None
    map_bounds_NSWE = (50., 20., -130., -60.)  # default (N, S, W, E) bounds used by grid_predict
    lats_list = None
    lons_list = None
    self_pickle_file_path = None

    def __init__(self, Y=None, X=None, verbose=True, lambda_smoothing=None, from_file_path=None):
        '''
        Initialize by giving Y and X as numpy arrays. Y must be (n x 1) and X must be (n x 3) with the first
        coordinate set to 1.0 (i.e. homogenous coordinates).

        :param Y: numpy array (N x 1), dependent variable
        :param X: numpy array (N x 3), locations in homogenous coordinates (i.e. [1.0, x, y])
        :param verbose: if true, fitting steps print progress and timing information
        :param lambda_smoothing: optional override of the smoothing parameter (class default 0.1)
        :param from_file_path: if this names an existing file, the model is loaded from that pickle file
            instead, and Y, X and lambda_smoothing are ignored
        '''
        # Set verbosity first: the load-from-file branch returns early, and other methods read self.verbose.
        self.verbose = bool(verbose)
        # If a file path to load is specified, default to that...
        if from_file_path is not None and os.path.isfile(from_file_path):
            self.load_from_pickle_file(from_file_path)
            return
        # ...otherwise use the inputs (if any are given).
        if X is not None and Y is not None:
            Yn = Y.shape[0]
            Xn = X.shape[0]
            assert Yn == Xn, 'Y and X must be the same height. (Y is shape %s, X is shape %s)' % (str(Y.shape), str(X.shape))
            assert Y.shape[1] == 1, 'Y must have width 1. (Y is shape %s)' % str(Y.shape)
            assert X.shape[1] == 3, 'X must have width 3. (X is shape %s)' % str(X.shape)
            assert np.all(X[:, 0] == 1.0), 'X must have 1.0 in the first column of every row.'
            self.n = Xn
            self.Y = Y
            self.X = X
        if lambda_smoothing is not None:
            self.lambdasmooth = lambda_smoothing

    def save_as_pickle_file(self, file_path=None):
        '''Saves the most important attributes of this object as a python dictionary in a pickle file.

        :param file_path: target path. If omitted, the path remembered from a previous save/load is used.
        :raises FileNotFoundError: if no path is given and none has been remembered.
        '''
        if file_path is not None:
            self.self_pickle_file_path = file_path
        if self.self_pickle_file_path is None:
            raise FileNotFoundError('No pickle file path was given or previously set.')
        tps_data = {
            'X': self.X,
            'Y': self.Y,
            'n': self.n,
            'K': self.K,
            'lambdasmooth': self.lambdasmooth,
            'w_hat': self.w_hat,
            'a_hat': self.a_hat,
            'Y_hat': self.Y_hat,
            'e_hat': self.e_hat,
            'std_err_resid': self.std_err_resid,
            'Xnew': self.Xnew,
            'Ynew_hat': self.Ynew_hat,
            'Knew': self.Knew,
            'self_pickle_file_path': self.self_pickle_file_path
        }
        io_save_to_pickle(tps_data, self.self_pickle_file_path)

    def load_from_pickle_file(self, file_path=None):
        '''Loads the data for a previously created TPS from a pickle file written by self.save_as_pickle_file.

        SECURITY NOTE: unpickling can execute arbitrary code. Only load files from trusted sources.

        :param file_path: source path. If omitted, the path remembered from a previous save/load is used.
        :raises FileNotFoundError: if no path is given and none has been remembered.
        '''
        if file_path is not None:
            self.self_pickle_file_path = file_path
        if self.self_pickle_file_path is None:
            raise FileNotFoundError('No pickle file path was given or previously set.')
        tps_data = io_load_from_pickle(self.self_pickle_file_path)
        self.X = tps_data['X']
        self.Y = tps_data['Y']
        self.n = tps_data['n']
        self.K = tps_data['K']
        self.lambdasmooth = tps_data['lambdasmooth']
        self.w_hat = tps_data['w_hat']
        self.a_hat = tps_data['a_hat']
        self.Y_hat = tps_data['Y_hat']
        self.e_hat = tps_data['e_hat']
        self.std_err_resid = tps_data['std_err_resid']
        self.Xnew = tps_data['Xnew']
        self.Ynew_hat = tps_data['Ynew_hat']
        self.Knew = tps_data['Knew']

    def copy_from_existing_TPS(self, existing_TPS):
        '''
        Takes an existing ThinPlateSpline object and turns self into a (shallow) copy of its training data
        and fitted model. Arrays are shared, not duplicated.
        '''
        self.n = existing_TPS.n
        self.Y = existing_TPS.Y
        self.X = existing_TPS.X
        self.verbose = existing_TPS.verbose
        self.lambdasmooth = existing_TPS.lambdasmooth
        # 'Kernel_Xnew' is never set by this class; copy it only if the source object happens to carry it.
        if hasattr(existing_TPS, 'Kernel_Xnew'):
            self.Kernel_Xnew = existing_TPS.Kernel_Xnew
        self.K = existing_TPS.K
        self.w_hat = existing_TPS.w_hat
        self.a_hat = existing_TPS.a_hat
        self.Y_hat = existing_TPS.Y_hat
        self.e_hat = existing_TPS.e_hat
        self.std_err_resid = existing_TPS.std_err_resid

    def compute_kernel_matrix(self):
        '''
        Computes the (N x N) kernel matrix for the training data points X_1, .... , X_N and stores it in self.K.
        '''
        st = datetime.datetime.now()
        if self.verbose:
            print('Computing K matrix beginning at %s' % st.strftime(DATETIME_FMT_PRINTABLE), end='', flush=True)
        # Column 0 is the homogeneous 1.0, so only the (x, y) columns enter the kernel.
        self.K = tps_kernel_matrix(self.X[:, 1:], self.X[:, 1:])
        if self.verbose:
            et = datetime.datetime.now()
            print(', done... (took %s)' % (et - st))

    def solve_parameters(self):
        '''
        Solves the TP-Spline for the input values. Gets the W and A parameter vectors and additionally computes
        fitted Y values and residuals.

        The linear system solved is
            [ K + n*lambdasmooth*I   X ] [w]   [Y]
            [ X^T                    0 ] [a] = [0]
        where the zero block is 3 x 3. The kernel matrix is computed first if it has not been computed yet.

        :raises numpy.linalg.LinAlgError: if the system matrix is singular.
        '''
        if self.K is None:
            if self.verbose:
                print('Kernel matrix has not yet been computed. Doing that first.')
            self.compute_kernel_matrix()
        st = datetime.datetime.now()
        if self.verbose:
            print('Solving thin-plate spline params, beginning at %s' % st.strftime(DATETIME_FMT_PRINTABLE),
                  end='', flush=True)
        n = self.n
        L = np.vstack((np.hstack((self.K + n * self.lambdasmooth * np.identity(n), self.X)),
                       np.hstack((self.X.T, np.zeros((3, 3), dtype=np.float64)))))
        rhs = np.vstack((self.Y, np.zeros((3, 1), dtype=np.float64)))
        # Solving directly is cheaper and numerically more stable than forming inv(L) and multiplying.
        wa_hat = np.linalg.solve(L, rhs)
        self.w_hat = wa_hat[:n]
        self.a_hat = wa_hat[n:]
        self.fit_Y_hat()
        if self.verbose:
            et = datetime.datetime.now()
            print(', done... (took %s)' % (et - st))

    def fit_Y_hat(self):
        '''
        For all the values of the training data, computes the 'fitted' predicted value 'Y-hat'. Also computes the
        residual 'e_hat' and the residual standard deviation 'std_err_resid'.
        '''
        self.Y_hat = self.X.dot(self.a_hat) + self.K.dot(self.w_hat)
        self.e_hat = self.Y - self.Y_hat
        self.std_err_resid = np.std(self.e_hat).item()

    def report_diagnostics(self):
        '''
        This is a work in progress. Essentially should be sort of like a 'summary' function in R. At the very least
        we want sample size, std err and that sort of stuff. Currently prints the sample size and the mean and
        standard deviation of the residuals.
        '''
        print('Sample Size: %s' % self.n)
        print('Mean Residual: %s' % np.mean(self.e_hat).item())
        print('Residual S.D.: %s' % np.std(self.e_hat).item())

    def predict(self, Xnew, store_XYKnew=False):
        '''
        Computes predicted value of the spline at a new point(s) Xnew. Xnew must be in the same format as self.X
        (i.e. for K different points it should be (K x 3) with each row in homogenous coordinates: [1.0, x, y]).
        Optionally the inputs, predictions and kernel matrix are stored on the model so they need not be
        recalculated. Any previously stored predictions are cleared first (see clear_predictions()).

        :param Xnew: (K x 3) numpy array of points in X-domain for which to compute fitted value.
        :param store_XYKnew: if true, also store Xnew, the predictions and the kernel matrix on self.
        :return: (K x 1) numpy array of predicted values (returned in both modes).
        '''
        assert Xnew.shape[1] == 3, 'variable Xnew must have width 3'
        Kmat_Xnew = tps_kernel_matrix(Xnew[:, 1:], self.X[:, 1:])
        Ynew_hat = Xnew.dot(self.a_hat) + Kmat_Xnew.dot(self.w_hat)
        if store_XYKnew:
            self.clear_predictions()
            self.Xnew = Xnew
            self.Ynew_hat = Ynew_hat
            self.Knew = Kmat_Xnew
        return Ynew_hat

    def clear_predictions(self):
        '''Removes old values of self.Xnew, self.Ynew_hat, self.Knew by resetting them to None.'''
        self.Xnew = None
        self.Ynew_hat = None
        self.Knew = None

    def grid_predict(self, lon_resolution=1.0, lat_resolution=1.0, map_bounds_NSWE=None):
        '''
        Builds a regular lon/lat grid over the map bounds and predicts the spline on it, keeping only grid points
        whose longitude lies within the range of the training data in the same latitude band.

        Results are stored in self.prediction_grid_points ((P x 3), rows [1.0, lon, lat]),
        self.prediction_grid_values ((P x 1)) and self.prediction_grid_image ((n_lats x n_lons), NaN where no
        prediction was made; rows follow self.lats_list, columns follow self.lons_list).

        :param lon_resolution: requested spacing in longitude (adjusted to fit the bounds exactly).
        :param lat_resolution: requested spacing in latitude; also used as the band half-width when matching
            training points to each grid latitude.
        :param map_bounds_NSWE: optional (N, S, W, E) bounds; if given it replaces self.map_bounds_NSWE.
        '''
        if map_bounds_NSWE is not None:
            self.map_bounds_NSWE = map_bounds_NSWE
        self.lons_list, self.lats_list, _, _ = get_lat_lon_lists(lon_resolution, lat_resolution, self.map_bounds_NSWE)
        n_lons = len(self.lons_list)
        n_lats = len(self.lats_list)
        # Map each grid coordinate to its first position in its list. This is the same lookup list.index()
        # does, but O(1) per point.
        lat_pos = {}
        for i, lat in enumerate(self.lats_list):
            lat_pos.setdefault(lat, i)
        lon_pos = {}
        for j, lon in enumerate(self.lons_list):
            lon_pos.setdefault(lon, j)
        point_list = []
        img_rows = []
        img_cols = []
        for mylat in self.lats_list:
            mylons = get_lons_in_Xrange_for_horizontal_band(self.X, mylat, lat_resolution, self.lons_list, lon_resolution)
            for mylon in mylons:
                point_list.append((1., mylon, mylat))
                img_rows.append(lat_pos[mylat])
                img_cols.append(lon_pos[mylon])
        # reshape keeps the (P x 3) shape even when no grid point qualifies (P == 0).
        self.prediction_grid_points = np.array(point_list, dtype=np.float64).reshape(-1, 3)
        self.prediction_grid_values = self.predict(self.prediction_grid_points)
        self.prediction_grid_image = np.full((n_lats, n_lons), np.nan, dtype=np.float64)
        # Fill in the prediction image in one vectorized assignment.
        self.prediction_grid_image[np.array(img_rows, dtype=np.intp), np.array(img_cols, dtype=np.intp)] = \
            np.asarray(self.prediction_grid_values).reshape(-1)
# | |
#DEBUGGING/MISC UTILITY FUNCTIONS: | |
# | |
def np_dtype_size(dt):
    '''DEBUGGING function. Takes a numpy dtype (or anything numpy.dtype() accepts, such as np.float64 or 'int32')
    and returns the number of bytes per data item. E.g. np.float64 will return 8.'''
    # np.dtype(...).itemsize is authoritative. Parsing a trailing bit count out of str(dt) fails for scalar
    # types (str(np.float64) is "<class 'numpy.float64'>") and for dtypes such as bool or object.
    return np.dtype(dt).itemsize
def get_nparray_size_data(myarr, return_str=False):
    '''DEBUGGING function. Returns a tuple (or optionally a string representation thereof) of
    some key bits of data about how big a matrix is: (# bytes, shape, size, data-type).

    For scipy sparse matrices, '# bytes' is the total size of the numpy arrays holding the stored values
    and the index structure: 'data' plus whichever of 'indices', 'indptr', 'row', 'col' and 'offsets'
    the format has (CSR/CSC/BSR, COO, DIA).
    '''
    # Import the submodule explicitly; 'import scipy' alone does not guarantee scipy.sparse is loaded.
    import scipy.sparse
    if scipy.sparse.issparse(myarr):
        tot_bytes = 0
        for attr in ('data', 'indices', 'indptr', 'row', 'col', 'offsets'):
            part = getattr(myarr, attr, None)
            if isinstance(part, np.ndarray):
                tot_bytes += part.nbytes
    else:
        tot_bytes = myarr.nbytes
    shp = myarr.shape
    sz = myarr.size
    dt = myarr.dtype
    if return_str:
        return '(%d, %s, %d, %s)' % (tot_bytes, shp, sz, dt)
    return tot_bytes, shp, sz, dt
def checkpoint(msg, verbose=VERBOSE):
    '''DEBUGGING function. Appends (current time, msg) to the module-level checkpoint_list and, if verbose,
    prints the time elapsed since the previous check point and since the first one (whole seconds only).

    :param msg: label stored with this check point.
    :param verbose: print the timing line (defaults to the module's VERBOSE flag at import time).
    '''
    now = datetime.datetime.now()
    since_last = now - checkpoint_list[-1][0]
    since_first = now - checkpoint_list[0][0]
    if verbose:
        # Drop microseconds so the printout stays compact.
        since_last = datetime.timedelta(since_last.days, since_last.seconds, 0)
        since_first = datetime.timedelta(since_first.days, since_first.seconds, 0)
        # Exactly four placeholders for four values. The earlier format had an extra '%s (GB)' field with no
        # argument supplied, which raised TypeError on every verbose call.
        print('%s - %s (%s cume) -- %s' % (now.strftime(DATETIME_FMT_PRINTABLE), since_last, since_first, msg))
    checkpoint_list.append((now, msg))
def io_save_to_pickle(dat, file_path):
    '''Pickle `dat` into `file_path` (a relative path is resolved against the current working directory).'''
    import pickle
    with open(os.path.abspath(file_path), 'wb') as out_file:
        pickle.dump(dat, out_file)
def io_load_from_pickle(file_path):
    '''Load and return the object pickled in `file_path` (relative paths resolve against the working directory).

    SECURITY NOTE: unpickling can execute arbitrary code. Only load files from trusted sources.
    '''
    import pickle
    with open(os.path.abspath(file_path), 'rb') as in_file:
        return pickle.load(in_file)
def download_tps_sample_data():
    '''For testing and illustration, I've stuck some sample lat-lon data with a dependent variable on dropbox.
    This function goes and gets that data and returns it as two numpy arrays suitable for testing and a basic
    example.

    :return: tuple (Y, X). Y is a float64 array built from the first line of the file. X is a float64 array whose
        columns are the second and third lines of the file (presumably two equal-length coordinate lists, giving
        (n x 2) -- verify). Note that ThinPlateSpline expects Y as (n x 1) and X as (n x 3) with a leading 1.0
        column, so callers need to reshape/augment these.
    '''
    import ast
    import urllib.request
    lnk = 'https://www.dropbox.com/s/raw/m35fbcorigw7l95/test_spline_data.txt'
    # Close the HTTP response deterministically instead of leaving it open.
    with urllib.request.urlopen(lnk) as resp:
        dat = [line.decode('utf-8') for line in resp.readlines()]
    # This is downloaded remote content. Parse it with ast.literal_eval (Python literals only) rather than
    # eval(), which would execute arbitrary code if the remote file were ever replaced.
    y_list = ast.literal_eval(dat[0].strip())
    X = np.array([ast.literal_eval(dat[1].strip()), ast.literal_eval(dat[2].strip())], dtype=np.float64).T
    Y = np.array(y_list, dtype=np.float64)
    return Y, X
def get_lat_lon_lists(lon_res, lat_res, NSWE_bounds):
    '''Returns two lists of evenly spaced points spanning the map bounds: first the longitudes, then the latitudes.
    The desired resolutions are adjusted slightly downward so that an integer number of steps fits exactly on the
    map. The actual (adjusted) resolutions are returned along with the lists.

    :param lon_res: requested longitude spacing.
    :param lat_res: requested latitude spacing.
    :param NSWE_bounds: (N, S, W, E) bounds. N and S may be given in either order.
    :return: (lon_list, lat_list, lon_res_adj, lat_res_adj). Longitudes are wrapped into [-180, 180).

    NOTE: if E < W the bounds are treated as crossing the international date line and the east-west span is
    360 - |E - W|. E == W is treated as a full 360-degree span. If N == S, lat_list is the single latitude S
    and lat_res_adj is the requested lat_res.
    '''
    # TODO: make downstream grid code handle East < West (straddling the international date line).
    N = max(NSWE_bounds[0], NSWE_bounds[1])
    S = min(NSWE_bounds[0], NSWE_bounds[1])
    E = NSWE_bounds[3]
    W = NSWE_bounds[2]
    NSdist = (N - S)
    EWdist = (E - W) if E > W else 360. - abs(E - W)  # if bounds cross dateline then take complement
    longitude_mod_func = lambda x: (x + 180.) % 360 - 180.
    # Number of steps = dist / res rounded up (an exact multiple needs no extra step).
    n_lat_levels = int(NSdist / lat_res) + 1
    if NSdist / lat_res == int(NSdist / lat_res):
        n_lat_levels -= 1
    n_lon_levels = int(EWdist / lon_res) + 1
    if EWdist / lon_res == int(EWdist / lon_res):
        n_lon_levels -= 1
    if n_lat_levels == 0:
        # N == S: a single latitude row. Dividing the zero span by zero steps would raise ZeroDivisionError.
        lat_res_adj = float(lat_res)
        lat_list = [S]
    else:
        lat_res_adj = NSdist / n_lat_levels
        lat_list = [S + i * lat_res_adj for i in range(n_lat_levels + 1)]
    lon_res_adj = EWdist / n_lon_levels
    lon_list = [longitude_mod_func(W + i * lon_res_adj) for i in range(n_lon_levels + 1)]
    return lon_list, lat_list, lon_res_adj, lat_res_adj
def get_lons_in_Xrange_for_horizontal_band(X, center_lat, lat_halfwidth, lon_list, lon_res):
    '''For a particular horizontal band on the map, return the longitudes from lon_list that fall within the
    longitude range of the training points in that band, extended by (just under) lon_res on each side.

    :param X: numpy array (N x 3) with rows [1.0, lon, lat].
    :param center_lat: latitude at the center of the band.
    :param lat_halfwidth: the band covers [center_lat - lat_halfwidth, center_lat + lat_halfwidth].
    :param lon_list: candidate grid longitudes.
    :param lon_res: padding added on each side of the band's longitude range (strict inequality).
    :return: sorted list of the selected longitudes. Empty if no training point lies in the band.
    '''
    # TODO: figure out what to do if East-West map boundaries span the international date line.
    minlat = center_lat - lat_halfwidth
    maxlat = center_lat + lat_halfwidth
    in_band = (X[:, 2] >= minlat) & (X[:, 2] <= maxlat)
    if not np.any(in_band):
        # np.min/np.max raise ValueError on an empty selection. A band with no data simply gets no grid points.
        return []
    band_lons = X[in_band, 1]
    lo = np.min(band_lons) - lon_res
    hi = np.max(band_lons) + lon_res
    return sorted(x for x in lon_list if lo < x < hi)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment