Created
March 28, 2015 08:38
-
-
Save nkt1546789/be1e02e715d3bc462fcc to your computer and use it in GitHub Desktop.
Implementation of weighted kNN.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy | |
from numpy import random | |
from matplotlib import pyplot | |
from sklearn import datasets | |
from sklearn import base | |
from sklearn import metrics | |
from sklearn import grid_search | |
class NonParametricKNN(base.BaseEstimator):
    """Kernel-weighted nearest-neighbour classifier.

    Every training point votes for its own class with weight
    ``exp(-||x - x_i||^2 / (2 * sigma^2))`` (an RBF kernel); the predicted
    class is the one that collects the largest total weight.  No neighbour
    count ``k`` is needed -- only the kernel bandwidth ``sigma``.

    Parameters
    ----------
    sigma : float, default 0.1
        Bandwidth of the RBF kernel.
    """

    def __init__(self, sigma=0.1):
        self.sigma = sigma

    def fit(self, X, y):
        """Memorise the training set.

        Parameters
        ----------
        X : training samples, one row per sample.
        y : training labels, one per row of ``X``.

        Returns
        -------
        self
        """
        self.X = X
        self.y = y
        self.n = len(y)
        self.ul = numpy.unique(self.y)  # sorted distinct labels
        self.c = len(self.ul)
        return self

    def predict(self, X):
        """Return the predicted label for every row of ``X``."""
        # Phi[i, j] = RBF similarity between query i and training sample j.
        Phi = metrics.pairwise_kernels(
            X, self.X, metric="rbf", gamma=1. / (2 * self.sigma ** 2))
        # One-hot class membership: Y[j, k] is True iff y[j] == ul[k].
        Y = numpy.asarray(self.y).reshape(-1, 1) == self.ul
        votes = Phi.dot(Y)
        # argmax yields a column index into ``ul``; translate it back to the
        # actual label so that label sets other than 0..c-1 are handled.
        return self.ul[numpy.argmax(votes, axis=1)]

    def score(self, X, y):
        """Return the fraction of rows of ``X`` whose prediction equals ``y``."""
        hits = numpy.asarray(self.predict(X)) == numpy.asarray(y)
        return numpy.sum(hits) / float(len(y))
if __name__ == '__main__':
    # Demo: select the bandwidth by cross-validation on a noisy two-circles
    # problem, report the held-out accuracy and plot the predictions.
    # NOTE: ``sklearn.grid_search`` only exists in scikit-learn < 0.20; newer
    # releases provide GridSearchCV in ``sklearn.model_selection``.
    n = 300
    ntr = int(0.7 * n)  # 70% of the samples form the training split

    # datasets.make_moons accepts the same arguments and can be used instead.
    X, y = datasets.make_circles(n_samples=n, noise=0.05)

    # Random train/test split.
    idx = random.permutation(n)
    itr = idx[:ntr]
    ite = idx[ntr:]

    params = {"sigma": numpy.logspace(-1, 1, 10)}
    gs = grid_search.GridSearchCV(NonParametricKNN(), params).fit(X[itr], y[itr])
    model = gs.best_estimator_

    # Single-argument print calls behave identically on Python 2 and 3.
    print("best CV score: %s" % (gs.best_score_,))
    print("best parameters: %s" % (gs.best_params_,))
    accuracy = sum(y[ite] == model.predict(X[ite])) / float(len(y[ite]))
    print("classification accuracy: %s" % (accuracy,))

    # Colour every sample (train and test) by its predicted class.
    ypred = model.predict(X)
    pyplot.scatter(X[:, 0], X[:, 1], c=ypred, marker="o", cmap=pyplot.cm.winter)
    pyplot.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
We don't have to choose "k". Although we have to select the bandwidth "sigma", it can be done by cross-validation.