Instantly share code, notes, and snippets.

# nkt1546789/rbfmodel_wrapper.py (created Aug 24, 2015)

What would you like to do?
 # rbfmodel_wrapper.py: fit any scikit-learn style model on RBF kernel features.
import numpy as np


class RbfModelWrapper(object):
    """Fit a wrapped model on Gaussian (RBF) kernel features.

    On ``fit`` the training inputs are stored. Every sample ``x`` is then
    represented by its similarity to each stored training sample ``x_j``,
    ``Phi[:, j] = exp(-gamma * ||x - x_j||**2)``. The wrapped model is
    trained and evaluated on ``Phi`` instead of the raw inputs.

    Parameters
    ----------
    model : object
        Estimator exposing ``fit``, ``predict``, ``score``, ``get_params``
        and ``set_params`` (scikit-learn style).
    gamma : float, default 1.0
        Kernel width parameter.
    **kwds
        Ignored. They are accepted so that ``sklearn.base.clone`` can rebuild
        the wrapper from the flat dict returned by :meth:`get_params`,
        which also contains the wrapped model's own parameters.
    """

    def __init__(self, model, gamma=1., **kwds):
        self._model = model
        self.gamma = gamma

    @staticmethod
    def _sq_norms(X):
        """Return the squared row norms of 2-D ``X`` as an (n, 1) column."""
        return np.c_[np.sum(X ** 2, 1)]

    def _kernel(self, X, ref, ref_sq):
        """RBF similarities between rows of ``X`` and rows of ``ref``.

        ``ref_sq`` is ``_sq_norms(ref)``. It is passed in so that the stored
        training norms are not recomputed.
        """
        X_sq = self._sq_norms(X)
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b. Round-off can push the
        # result slightly below zero, so clip it to keep exp(...) <= 1.
        sq_dist = np.maximum(X_sq + ref_sq.T - 2 * X.dot(ref.T), 0.)
        return np.exp(-self.gamma * sq_dist)

    def fit(self, X, y):
        """Fit the wrapped model on the RBF Gram matrix of ``X``.

        ``X`` becomes the set of kernel centres used by ``predict``/``score``.
        Returns ``self``.
        """
        X = np.asarray(X)
        X2 = self._sq_norms(X)
        self._model.fit(self._kernel(X, X, X2), y)
        # Stored only after a successful fit, matching the original order.
        self.X = X
        self.X2 = X2
        return self

    def predict(self, X):
        """Predict with the wrapped model on RBF features of ``X``.

        Requires a prior call to ``fit``.
        """
        X = np.asarray(X)
        return self._model.predict(self._kernel(X, self.X, self.X2))

    def score(self, X, y):
        """Return the wrapped model's ``score`` on RBF features of ``X``.

        Requires a prior call to ``fit``.
        """
        X = np.asarray(X)
        return self._model.score(self._kernel(X, self.X, self.X2), y)

    def get_params(self, deep=True):
        """Return the wrapped model's params plus ``gamma`` and ``model``.

        ``setdefault`` is used, so if the wrapped model itself has a
        ``gamma`` or ``model`` parameter, the model's value is reported.
        NOTE(review): such a name collision is ambiguous; confirm it is not
        relied upon.
        """
        params = self._model.get_params(deep=deep)
        params.setdefault("gamma", self.gamma)
        params.setdefault("model", self._model)
        return params

    def set_params(self, **params):
        """Set wrapper and wrapped-model parameters; returns ``self``.

        ``model`` replaces the wrapped estimator and ``gamma`` sets the
        kernel width. All remaining keys are forwarded to the (possibly new)
        wrapped model's ``set_params``.
        """
        # Replace the model first so the leftover keys reach the new one.
        if "model" in params:
            self._model = params.pop("model")
        # Previously this value was popped and discarded, so grid searches
        # over gamma silently had no effect.
        if "gamma" in params:
            self.gamma = params.pop("gamma")
        if params:
            self._model.set_params(**params)
        return self


if __name__ == "__main__":
    # Demo: two Gaussian blobs, grid-search gamma for an RBF-featured
    # logistic regression, then plot the resulting decision boundary.
    from numpy import random
    from sklearn.linear_model import LogisticRegression
    # sklearn.grid_search was removed in scikit-learn 0.20.
    from sklearn.model_selection import GridSearchCV
    import matplotlib.pyplot as plt
    from matplotlib.lines import Line2D

    random.seed(1)
    n1 = 100
    n2 = 500
    n = n1 + n2
    mu1 = [0, 0]
    mu2 = [2, 0]
    Sigma1 = 0.1 * np.identity(2)
    Sigma2 = 0.5 * np.identity(2)
    X = np.r_[random.multivariate_normal(mu1, Sigma1, n1),
              random.multivariate_normal(mu2, Sigma2, n2)]
    y = np.concatenate([np.repeat(1, n1), np.repeat(0, n2)])
    idx = random.permutation(n)
    X = X[idx]
    y = y[idx]
    ntr = int(n * 0.7)
    itr = idx[:ntr]
    ite = idx[ntr:]

    gs = GridSearchCV(RbfModelWrapper(LogisticRegression()),
                      param_grid={"gamma": np.logspace(-2, 0, 9)})
    gs.fit(X[itr], y[itr])
    print(gs.best_score_)
    print(gs.best_params_)
    clf = gs.best_estimator_
    print(clf.score(X[ite], y[ite]))

    offset = .5
    xx, yy = np.meshgrid(
        np.linspace(X[:, 0].min() - offset, X[:, 0].max() + offset, 300),
        np.linspace(X[:, 1].min() - offset, X[:, 1].max() + offset, 300))
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
    plt.contour(xx, yy, Z, levels=[0.5], linewidths=2, colors='green')
    # ContourSet.collections is removed in recent Matplotlib; a proxy
    # artist gives the legend entry for the contour line instead.
    boundary = Line2D([], [], color='green', linewidth=2)
    b1 = plt.scatter(X[y == 1][:, 0], X[y == 1][:, 1], c="blue", s=50)
    b2 = plt.scatter(X[y == 0][:, 0], X[y == 0][:, 1], c="red", s=50)
    plt.axis("tight")
    plt.xlim((X[:, 0].min() - offset, X[:, 0].max() + offset))
    plt.ylim((X[:, 1].min() - offset, X[:, 1].max() + offset))
    plt.legend([boundary, b1, b2],
               [r"p(y|x)=0.5", "positive", "unlabeled"],
               prop={"size": 10})
    plt.tight_layout()
    plt.show()