Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
Local regression
"""Local regression"""
# Author: Mathieu Blondel <mathieu@mblondel.org>
# License: BSD 3 clause
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.metrics.pairwise import pairwise_kernels
from sklearn.linear_model import Ridge
class LocalRegression(BaseEstimator, RegressorMixin):
    """
    Local regression

    Lazy (memory-based) regressor: ``fit`` only stores the training data.
    At prediction time one Ridge model is fitted per query point on the
    whole training set, with each training sample weighted by its kernel
    value against that query point; that model's prediction at the query
    point is the output.

    Parameters
    ----------
    alpha : float, default=1
        Regularization strength of each per-query ``Ridge`` model.
    kernel : str or callable, default="linear"
        Passed as ``metric`` to ``sklearn.metrics.pairwise.pairwise_kernels``
        to compute the sample weights. ``"precomputed"`` means the inputs
        are already kernel matrices.
    gamma, degree, coef0 : optional
        Parameters for string kernels. ``pairwise_kernels`` is called with
        ``filter_params=True``, so only the ones the chosen kernel accepts
        are used. Ignored when ``kernel`` is callable.
    kernel_params : dict or None, default=None
        Keyword arguments for a callable ``kernel``; ignored otherwise.

    Notes
    -----
    Kernel values are used directly as ``sample_weight``. Kernels that can
    return negative values (such as the default ``"linear"``) therefore give
    negative sample weights; a non-negative kernel such as ``"rbf"`` is
    presumably the intended choice.

    Reference
    ---------
    Elements of Statistical Learning, Section 6.3.
    """

    def __init__(self, alpha=1, kernel="linear", gamma=None, degree=3, coef0=1,
                 kernel_params=None):
        # scikit-learn convention: store constructor arguments unchanged so
        # get_params / set_params / clone keep working.
        self.alpha = alpha
        self.kernel = kernel
        self.gamma = gamma
        self.degree = degree
        self.coef0 = coef0
        self.kernel_params = kernel_params

    @property
    def _pairwise(self):
        # Legacy scikit-learn estimator tag: tells utilities such as
        # cross-validation that X is a kernel matrix when precomputed.
        return self.kernel == "precomputed"

    def _get_kernel(self, X, Y=None):
        """Return the kernel matrix between the rows of X and the rows of Y."""
        if callable(self.kernel):
            params = self.kernel_params or {}
        else:
            params = {"gamma": self.gamma,
                      "degree": self.degree,
                      "coef0": self.coef0}
        return pairwise_kernels(X, Y, metric=self.kernel,
                                filter_params=True, **params)

    def fit(self, X, y):
        """Store the training data; no model is fitted until ``predict``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features), or a kernel
            matrix when ``kernel="precomputed"``.
        y : array-like of shape (n_samples,)

        Returns
        -------
        self
        """
        self.X_train_ = X
        self.y_train_ = y
        return self

    def predict(self, X):
        """Predict one value per row of X with a kernel-weighted Ridge fit.

        Parameters
        ----------
        X : array-like of shape (n_queries, n_features), or a kernel matrix
            against the training data when ``kernel="precomputed"``.

        Returns
        -------
        pred : ndarray of shape (n_queries,), dtype float64
        """
        K = self._get_kernel(X, self.X_train_)
        # Take the row count from the kernel matrix rather than from X, so
        # X need not expose ``.shape`` (e.g. nested lists).
        n_samples = K.shape[0]
        pred = np.zeros(n_samples, dtype=np.float64)
        # One estimator refitted per query; each fit overwrites the last.
        # NOTE(review): with kernel="precomputed", Ridge is presumably fitted
        # on the training kernel matrix rows as features.
        reg = Ridge(alpha=self.alpha, fit_intercept=True)
        # ``range`` instead of the Python 2-only ``xrange`` (NameError on 3.x).
        for i in range(n_samples):
            reg.fit(self.X_train_, self.y_train_, sample_weight=K[i])
            # Slicing keeps the query as a single-row 2-D input without
            # relying on ``reshape``.
            pred[i] = reg.predict(X[i:i + 1])[0]
        return pred
if __name__ == '__main__':
    import matplotlib.pyplot as plt

    def f(x):
        """Target curve that the local regression tries to recover."""
        return x * np.sin(x)

    # Dense grid on [0, 10] on which the true curve and every fit are drawn.
    grid = np.linspace(0, 10, 100)

    # Pick 20 training abscissae from the same 100 candidates: shuffle with
    # a fixed seed, keep the first 20, then put them back in order.
    candidates = np.linspace(0, 10, 100)
    np.random.RandomState(0).shuffle(candidates)
    x_train = np.sort(candidates[:20])
    y_train = f(x_train)

    # Estimators take 2-D inputs with a single feature column.
    X_train = x_train.reshape(-1, 1)
    X_grid = grid.reshape(-1, 1)

    plt.plot(grid, f(grid), label="ground truth")
    plt.scatter(x_train, y_train, label="training points")

    # Larger gamma gives a narrower RBF kernel, i.e. a more local fit.
    for gamma in (0.1, 0.5, 1.0, 3.0, 5.0):
        model = LocalRegression(kernel="rbf", gamma=gamma).fit(X_train, y_train)
        plt.plot(grid, model.predict(X_grid), label="gamma=%0.1f" % gamma)

    plt.legend(loc='lower left')
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment