Last active: October 30, 2017 18:53
-
-
Save johannah/3299af8fe7183d26dbf2b68e1c936b17 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy | |
import time | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from copy import deepcopy | |
import os, sys | |
from scipy.misc import face | |
import GPy | |
from skimage.transform import resize | |
from skimage.filters import gaussian_filter | |
from sklearn.cluster import MiniBatchKMeans, KMeans | |
from sklearn.preprocessing import StandardScaler | |
class FindDataDist():
    """Incrementally sample a 2-D field and model it with a GPy Gaussian process.

    Samples are added through ``update_gp``. The GP is fit either on all
    samples or, once there are more than ``max_num_samples`` of them, on a
    KMeans summary of them. It then predicts every pixel that has not been
    sampled yet. Cluster centres of the high-variance predictions are kept
    in ``self.uncertain`` (highest variance first).
    """

    def __init__(self, ysize=100, xsize=140, gtdata=None, do_plot=False):
        self.ysize, self.xsize = ysize, xsize
        self.do_plot = do_plot
        self.gtdata = gtdata
        if self.gtdata is None:
            self.gtdata = np.zeros((self.ysize, self.xsize), dtype=np.float32)
        # Seed sample: a value of 0 at (row 0, col 0).
        self.sam_locs = np.array([[0, 0]])
        self.sams = np.array([[0]])
        self.rng = np.random.RandomState(199)
        # sdata holds observed samples; -1.0 marks "not sampled yet"
        # (update_gp predicts exactly the pixels where sdata < 0).
        self.sdata = np.ones_like(self.gtdata) * -1.0
        self.pdata = deepcopy(self.sdata)        # predictions, for plotting
        self.kdata = np.zeros_like(self.sdata)   # KMeans-reduced samples
        self.vdata = np.zeros_like(self.gtdata)  # predictive variance
        self.num_new_samples_needed = 50
        self.max_num_samples = 400
        self.num_samples_last_gp = 0
        # Exponential covariance plus a white-noise term.
        self.ker = GPy.kern.Exponential(2, ARD=False, variance=1, lengthscale=.2)
        self.ker += GPy.kern.White(2)
        self.uncertain = [([0, 0], 0)]
        # n_jobs is not passed: it was removed from KMeans in scikit-learn 1.0.
        self.kmeans = KMeans(n_clusters=self.max_num_samples, random_state=self.rng, verbose=False)
        self.vkmeans = KMeans(n_clusters=4, random_state=self.rng, verbose=False)
        if self.do_plot:
            self.setup_plot()

    def resample_duplicates(self, locs):
        """Return the unique rows of ``locs``.

        The rows come back in ``np.lexsort`` order, i.e. sorted with the last
        column as the primary key. Despite the name, nothing is resampled.
        An empty input is returned unchanged.
        """
        locs = np.asarray(locs)
        if locs.shape[0] == 0:
            return locs
        sorted_data = locs[np.lexsort(locs.T), :]
        # Keep the first row, plus every row that differs from its predecessor.
        row_mask = np.append([True], np.any(np.diff(sorted_data, axis=0), 1)).ravel()
        return sorted_data[row_mask]

    def select_optimal_points(self, locs, sams, kk):
        """Summarise (location, value) points by the cluster centres of ``kk``.

        ``locs`` holds (row, col) index pairs and ``sams`` a column of values.
        The stacked columns are standardised, clustered with ``kk`` (a KMeans
        instance), and mapped back to the original scale.

        Returns ``(X, y)``:
        - ``X``: integer (row, col) centroid indices.
        - ``y``: the centroid values as an ``(n_clusters, 1)`` column.
        """
        ks = time.time()
        X = np.hstack((locs, sams)).astype(np.float32)
        # NOTE(review): a single scalar mean/std is used for all columns, so
        # the value column carries little weight next to the pixel indices.
        # Confirm this is intended.
        Xmean = X.mean()
        Xstd = X.std()
        Xscaled = (X - Xmean) / Xstd
        print("using kmeans on samples of size (%s,%s)" % (Xscaled.shape[0], Xscaled.shape[1]))
        kk.fit(Xscaled)
        clust_cent = (Xstd * kk.cluster_centers_) + Xmean
        # Centroid location indices (builtin int: np.int was removed in NumPy 1.24).
        X = np.rint(clust_cent[:, :2]).astype(int)
        # Centroid sample values as a column vector.
        y = clust_cent[:, 2].reshape(-1, 1)
        ke = time.time()
        print("Spent %s seconds on kmeans" % (ke - ks))
        return X, y

    def setup_plot(self):
        """Create the five-panel figure: truth, samples, predictions, KMeans, variance."""
        if self.do_plot:
            self.f, self.ax = plt.subplots(1, 5, sharey=True, figsize=(14, 5))
            self.ax0 = self.ax[0].imshow(self.gtdata, vmin=0, vmax=1, interpolation="None", origin='lower')
            self.ax[0].set_title("ground truth")
            self.ax[1].set_title("No samples")
            self.ax[2].set_title("No prediction")
            self.ax[3].set_title('KMeans Points')
            self.ax[4].set_title('Variance')
            self.ax1 = self.ax[1].imshow(self.sdata, vmin=0, vmax=1, interpolation="None", origin='lower')
            self.ax2 = self.ax[2].imshow(self.pdata, vmin=0, vmax=1, interpolation="None", origin='lower')
            self.ax3 = self.ax[3].imshow(self.kdata, vmin=0, vmax=1, interpolation="None", origin='lower')
            self.ax4 = self.ax[4].imshow(self.vdata, vmin=0, vmax=.1, interpolation="None", origin='lower')
            plt.show(block=False)

    def plot(self, pred_locs, pred_sam, pred_var):
        """Refresh the figure with the latest samples, predictions and variance.

        Panels keep the roles ``setup_plot`` gave them:
        - ax2: predictions
        - ax3: KMeans points
        """
        print("Plotting points")
        if not self.do_plot:
            return
        self.pdata[pred_locs[:, 0], pred_locs[:, 1]] = pred_sam.ravel()
        # Overlay the actual samples on top of the predictions.
        self.pdata[self.sam_locs[:, 0], self.sam_locs[:, 1]] = self.sams.ravel()
        self.vdata[pred_locs[:, 0], pred_locs[:, 1]] = pred_var.ravel()
        self.ax1.set_data(self.sdata)
        self.ax2.set_data(self.pdata)
        self.ax3.set_data(self.kdata)
        self.ax4.set_data(self.vdata)
        self.ax[1].set_title("Sampled %s" % (self.sam_locs.shape[0]))
        self.ax[2].set_title("%s preds" % pred_locs.shape[0])
        self.ax[3].set_title("%s kmeans" % self.max_num_samples)
        self.f.canvas.draw()

    def _cluster_uncertain(self, pred_locs, pred_var, min_var=.001):
        """Return (location, variance) centroids of high-variance predictions, highest first.

        Returns an empty list when there are fewer points above ``min_var``
        than ``self.vkmeans`` has clusters, since KMeans cannot fit that case.
        """
        hvar = pred_var[:, 0] > min_var
        if np.count_nonzero(hvar) < self.vkmeans.n_clusters:
            return []
        NX, ny = self.select_optimal_points(pred_locs[hvar], pred_var[hvar], self.vkmeans)
        # list() so that .sort works on Python 3, where zip is lazy.
        uncertain = list(zip(NX, ny))
        uncertain.sort(key=lambda pair: pair[1][0], reverse=True)
        return uncertain

    def update_gp(self, new_locs, new_sams):
        """Add new samples, refit the GP, predict unsampled pixels, update ``uncertain``.

        ``new_locs`` holds (row, col) index pairs. ``new_sams`` holds the
        matching values as a column, so it can be vstacked onto ``self.sams``.
        Failures while fitting or predicting are printed, not raised.
        """
        # Prepend the new samples to the old ones.
        self.sam_locs = np.vstack((new_locs, self.sam_locs))
        self.sams = np.vstack((new_sams, self.sams))
        # Record the samples in the image from the arguments, not from
        # module-level script globals.
        self.sdata[new_locs[:, 0], new_locs[:, 1]] = np.ravel(new_sams)
        num_samples = self.sam_locs.shape[0]
        self.num_samples_last_gp = num_samples
        if num_samples > self.max_num_samples:
            # Too many samples: fit the GP on a KMeans summary instead.
            X, y = self.select_optimal_points(self.sam_locs, self.sams, self.kmeans)
            self.kdata = np.zeros_like(self.sdata)
            self.kdata[X[:, 0], X[:, 1]] = y.ravel()
        else:
            X, y = self.sam_locs, self.sams
        print("Starting new GP", X.shape, y.shape)
        # Only predict locations that have no data yet.
        pred_locs = np.asarray(np.where(self.sdata < 0)).T
        try:
            st = time.time()
            print("Building gp with %s points from %s samples" % (X.shape[0], self.sam_locs.shape[0]))
            # NOTE(review): self.ker is reused across models; confirm GPy's
            # parameter linking allows that in the installed version.
            m = GPy.models.GPRegression(X, y, self.ker)
            m.kern.lengthscale = 10
            m.constrain_positive('*')
            m.optimize('tnc', max_f_eval=1000)
            bt = time.time()
            print("build took %s secs" % (bt - st))
            print("Predicting %s points" % pred_locs.shape[0])
            pred_sam, pred_var = m.predict(pred_locs)
            et = time.time()
            print("predict took %s secs" % (et - bt))
            print("total took %s secs" % (et - st))
            self.uncertain = self._cluster_uncertain(pred_locs, pred_var)
            print("found uncertain points", self.uncertain)
            self.plot(pred_locs, pred_sam, pred_var)
        except Exception as e:
            # Best-effort: a failed fit must not abort the sampling loop.
            print("exception encountered %s" % e)
if __name__ == '__main__':
    # Standard deviation of the Gaussian noise added to each sample (it is
    # passed as ``scale`` to np.random.normal, so it is a std, not a variance).
    noise_std = .01
    ysize, xsize = 200, 150
    gtdata = gaussian_filter(resize(face()[:, :, 0], (ysize, xsize)), 5).astype(np.float32) + .0001
    fk = FindDataDist(ysize=ysize, xsize=xsize, gtdata=gtdata, do_plot=True)
    for i in range(6):
        plt.pause(.0001)
        print('----------------------------------')
        add_num_samples = np.random.randint(10, 400)
        print("adding %s new samples" % add_num_samples)
        # randint's upper bound is exclusive, so indices fall in [0, size-1].
        # This replaces the deprecated random_integers(0, size-1).
        ysams = np.random.randint(0, ysize, add_num_samples)
        xsams = np.random.randint(0, xsize, add_num_samples)
        new_locs = np.vstack((ysams, xsams)).T
        v = np.random.normal(loc=0, scale=noise_std, size=(new_locs.shape[0], 1))
        # Sample the ground truth with noise, clamped to the [0, 1] display range.
        new_sams = np.clip(gtdata[ysams, xsams].reshape(new_locs.shape[0], 1) + v, 0., 1.)
        fk.update_gp(new_locs, new_sams)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.