Skip to content

Instantly share code, notes, and snippets.

@luiarthur
Last active February 16, 2023 02:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save luiarthur/2cb248252e52372820b29cbdd755f57e to your computer and use it in GitHub Desktop.
Kernel Regression
import numpy as np
from scipy.spatial import KDTree
import matplotlib.pyplot as plt
# Diverging blue-white-red colour map used by the colour-coded plots below.
cmap = plt.get_cmap(name="bwr")
# True surface.
def f(X):
    """Evaluate the noise-free test surface at each row of ``X``.

    Parameters
    ----------
    X : array_like of shape (n, 2)
        Input points; column 0 and column 1 are the two coordinates.

    Returns
    -------
    numpy.ndarray of shape (n,)
        ``cos(10 * X[:, 0]) + 2 * sin(10 * X[:, 1])`` for every row.
    """
    # asarray lets plain nested lists be used as well as ndarrays.
    X = np.asarray(X)
    return np.cos(X[:, 0] * 10) + np.sin(X[:, 1] * 10) * 2
# Evaluation grid: 100 x 100 points covering the unit square.
grid_1d = np.linspace(0, 1, 100)
x0, x1 = np.meshgrid(grid_1d, grid_1d)
# One row per grid point; the columns are the (x0, x1) coordinates.
Xgrid = np.stack([x0.ravel(), x1.ravel()], axis=1)

# Fine-resolution view of the true surface.
true_surface = f(Xgrid).reshape(x0.shape)
plt.contourf(x0, x1, true_surface, levels=101, cmap=cmap)
plt.colorbar()
plt.show()

# Noisy observations: 100k uniform points in the unit square, with
# Gaussian noise of standard deviation sigma added to the true surface.
np.random.seed(0)
X = np.random.rand(100_000, 2)
sigma = 0.1
y = np.random.normal(loc=f(X), scale=sigma)

# Observations coloured by their noisy value.
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap)
plt.show()
# Squared exponential kernel.
def sqexpkernel(X, Y, length_scale=1, process_sd=1):
    """Squared exponential (RBF) kernel between the points of ``X`` and ``Y``.

    Computes ``process_sd**2 * exp(-0.5 * sum_j ((x_j - y_j) / l_j)**2)``.

    The last axis of each input indexes feature dimensions and the axis
    before it indexes points. Any leading axes are broadcast batch axes.
    For 2-D inputs ``X`` of shape (n, d) and ``Y`` of shape (p, d), the
    result is the (n, p) Gram matrix. For ``X`` of shape (m, 1, d) and
    ``Y`` of shape (m, k, d), the result has shape (m, 1, k).

    Parameters
    ----------
    X, Y : array_like
        Point sets with matching feature dimension ``d``.
    length_scale : float or array_like of shape (d,)
        Length scale. A vector gives one length scale per feature dimension.
    process_sd : float
        Marginal standard deviation. Identical points get kernel value
        ``process_sd**2``.

    Returns
    -------
    numpy.ndarray
        Kernel values for every (point in X, point in Y) pair.
    """
    X = np.asarray(X)
    Y = np.asarray(Y)
    # Move Y's point axis last so that broadcasting forms every pair:
    # X -> (..., n, d, 1), Y -> (..., 1, d, p), diff -> (..., n, d, p).
    diff = X[..., None] - np.swapaxes(Y[..., None], -1, -3)
    # The trailing singleton axis aligns a per-dimension length scale with
    # the feature axis (-2) of ``diff``. A scalar broadcasts unchanged.
    scale = np.asarray(length_scale)[..., None]
    sq_scaled_diff = (diff / scale) ** 2
    return np.exp(-sq_scaled_diff.sum(-2) / 2) * (process_sd ** 2)
# Kernel regression for scalar output.
class KernelRegression1D:
    """Nadaraya-Watson kernel regression restricted to the k nearest neighbours.

    A prediction is the kernel-weighted average of the training targets of
    the ``k`` nearest training points, which are found with a KD-tree. For
    each query point the weights are normalised to sum to one.

    Parameters
    ----------
    kernel : callable
        ``kernel(Xq, Xnn)`` is called with query points of shape (m, 1, d)
        and their neighbours of shape (m, k, d). It must return m * k
        weights, e.g. with shape (m, 1, k) or (m, k).
    """

    def __init__(self, kernel):
        self.kernel = kernel
        self.X = None       # training inputs, shape (n, d); set by fit()
        self.y = None       # training targets, shape (n,); set by fit()
        self.kdtree = None  # KD-tree over self.X; set by fit()

    def fit(self, X, y):
        """Store the training data and build a KD-tree over ``X``.

        Parameters
        ----------
        X : array_like of shape (n, d)
            Training inputs.
        y : array_like of shape (n,)
            Training targets.
        """
        self.X = np.asarray(X)
        self.y = np.asarray(y)
        self.kdtree = KDTree(self.X)

    def predict(self, X, k):
        """Predict targets at query points ``X`` from their ``k`` nearest neighbours.

        Parameters
        ----------
        X : array_like of shape (m, d)
            Query points.
        k : int
            Number of neighbours to average over. It is capped at the
            training-set size.

        Returns
        -------
        numpy.ndarray of shape (m,)
            Kernel-weighted neighbour averages.
        """
        X = np.asarray(X)
        m = X.shape[0]
        # KDTree pads missing neighbours with index n (out of range) when
        # k exceeds the number of training points, so cap k there.
        k = min(k, self.X.shape[0])
        nn_idx = self.kdtree.query(X, k=k)[1]
        # query() drops the neighbour axis when k == 1; restore shape (m, k).
        nn_idx = np.reshape(nn_idx, (m, k))
        K = self.kernel(X[:, None, :], self.X[nn_idx])
        # Reshape explicitly rather than squeeze(). squeeze() would also
        # collapse the query axis (m == 1) or the neighbour axis (k == 1).
        K = np.reshape(K, (m, k))
        W = K / K.sum(1)[:, None]
        return (W * self.y[nn_idx]).sum(-1)
# Fit a k-NN kernel regression with a unit-length-scale squared exponential kernel.
kr = KernelRegression1D(lambda a, b: sqexpkernel(a, b, length_scale=1))
kr.fit(X, y)

# Predict on the evaluation grid from the 15 nearest observations.
n_neighbors = 15
pred = kr.predict(Xgrid, n_neighbors)

# Predicted surface at fine resolution.
plt.contourf(x0, x1, pred.reshape(x0.shape), levels=101, cmap=cmap)
plt.colorbar()
plt.show()

# Predictions against the truth; a good fit lies close to the diagonal.
truth = f(Xgrid)
plt.scatter(truth, pred)
plt.show()
def rmse(x, y):
    """Return the root-mean-square error between ``x`` and ``y``.

    Inputs are converted with ``np.asarray``, so plain sequences work. They
    must have broadcast-compatible shapes.
    """
    diff = np.asarray(x) - np.asarray(y)
    return np.sqrt(np.mean(diff ** 2))
rmse_value = rmse(f(Xgrid), pred)
print("RMSE: ", rmse_value)
# To benchmark prediction in IPython/Jupyter: %timeit kr.predict(Xgrid, 15)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment