Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
import numpy as np
class RbfModelWrapper(object):
    """Fit a wrapped model on Gaussian (RBF) kernel features.

    Each input row ``x`` is mapped to the feature vector
    ``[exp(-gamma * ||x - x_i||^2) for x_i in X_train]`` and the wrapped
    model is trained and evaluated on those features. The class exposes the
    scikit-learn estimator protocol (``fit``/``predict``/``score``/
    ``get_params``/``set_params``) so it can be tuned with a grid search.

    Parameters
    ----------
    model : estimator
        Object exposing ``fit``, ``predict``, ``score``, ``get_params`` and
        ``set_params`` (e.g. ``LogisticRegression()``).
    gamma : float
        Kernel width parameter; larger values give narrower kernels.
    **kwds
        Extra parameters forwarded to ``model.set_params``. ``get_params``
        reports the model's parameters next to ``gamma``/``model``, so a
        constructor call rebuilt from ``get_params()`` passes them back here.
    """

    def __init__(self, model, gamma=1., **kwds):
        self._model = model
        self.gamma = gamma
        if kwds:
            # Previously these were silently discarded.
            self._model.set_params(**kwds)

    @staticmethod
    def _sq_norms(X):
        """Return the squared Euclidean norm of each row of X as an (n, 1) column."""
        return np.sum(X ** 2, axis=1)[:, np.newaxis]

    def _design(self, X):
        """Return the (len(X), n_train) RBF feature matrix of X against the training rows."""
        X = np.asarray(X, dtype=float)
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, evaluated for all pairs at once.
        sq_dist = self._sq_norms(X) + self.X2.T - 2 * X.dot(self.X.T)
        # Floating-point cancellation can leave tiny negative distances;
        # clamp them so kernel values never exceed 1.
        np.maximum(sq_dist, 0, out=sq_dist)
        return np.exp(-self.gamma * sq_dist)

    def fit(self, X, y):
        """Store X as the kernel centres and fit the model on its RBF features.

        Returns ``self``.
        """
        X = np.asarray(X, dtype=float)
        self.X = X
        self.X2 = self._sq_norms(X)
        self._model.fit(self._design(X), y)
        return self

    def predict(self, X):
        """Return ``model.predict`` on the RBF features of X (requires a prior ``fit``)."""
        return self._model.predict(self._design(X))

    def score(self, X, y):
        """Return ``model.score`` on the RBF features of X (requires a prior ``fit``)."""
        return self._model.score(self._design(X), y)

    def get_params(self, deep=True):
        """Return the wrapped model's parameters plus ``gamma`` and ``model``.

        The wrapper's own entries override any same-named model parameter,
        which matches how ``set_params`` routes them.
        """
        params = dict(self._model.get_params(deep=deep))
        params["gamma"] = self.gamma
        params["model"] = self._model
        return params

    def set_params(self, **params):
        """Set ``gamma``/``model`` on the wrapper and forward everything else to the model.

        Returns ``self``.
        """
        # The original popped "gamma" and discarded it, so tuning gamma had
        # no effect; it also raised KeyError when gamma was absent.
        if "gamma" in params:
            self.gamma = params.pop("gamma")
        if "model" in params:
            self._model = params.pop("model")
        if params:
            self._model.set_params(**params)
        return self
if __name__ == "__main__":
    # Demo: tune gamma by cross-validated grid search on a 2-D toy problem
    # (a tight class-1 cluster overlapping a broad class-0 cloud), report the
    # scores, and plot the resulting decision boundary.
    from numpy import random
    from sklearn.linear_model import LogisticRegression
    import matplotlib.pyplot as plt
    from matplotlib.lines import Line2D
    try:
        from sklearn.model_selection import GridSearchCV
    except ImportError:  # scikit-learn < 0.18; sklearn.grid_search was removed in 0.20
        from sklearn.grid_search import GridSearchCV

    random.seed(1)
    n1, n2 = 100, 500
    n = n1 + n2
    mu1, mu2 = [0, 0], [2, 0]
    Sigma1, Sigma2 = 0.1 * np.identity(2), 0.5 * np.identity(2)
    X = np.r_[random.multivariate_normal(mu1, Sigma1, n1),
              random.multivariate_normal(mu2, Sigma2, n2)]
    y = np.concatenate([np.repeat(1, n1), np.repeat(0, n2)])
    idx = random.permutation(n)
    X, y = X[idx], y[idx]
    # 70/30 train/test split. idx is reused here as a second permutation of
    # the already-shuffled data, which still yields a valid random split.
    ntr = int(n * 0.7)
    itr = idx[:ntr]
    ite = idx[ntr:]

    gs = GridSearchCV(RbfModelWrapper(LogisticRegression()),
                      param_grid={"gamma": np.logspace(-2, 0, 9)})
    gs.fit(X[itr], y[itr])
    print(gs.best_score_)
    print(gs.best_params_)
    clf = gs.best_estimator_
    print(clf.score(X[ite], y[ite]))

    # Evaluate the classifier on a dense grid and draw the 0/1 label boundary.
    offset = .5
    x_lo, x_hi = X[:, 0].min() - offset, X[:, 0].max() + offset
    y_lo, y_hi = X[:, 1].min() - offset, X[:, 1].max() + offset
    xx, yy = np.meshgrid(np.linspace(x_lo, x_hi, 300),
                         np.linspace(y_lo, y_hi, 300))
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
    plt.contour(xx, yy, Z, levels=[0.5], linewidths=2, colors='green')
    b1 = plt.scatter(X[y == 1][:, 0], X[y == 1][:, 1], c="blue", s=50)
    b2 = plt.scatter(X[y == 0][:, 0], X[y == 0][:, 1], c="red", s=50)
    plt.axis("tight")
    plt.xlim((x_lo, x_hi))
    plt.ylim((y_lo, y_hi))
    # Proxy artist for the contour: ContourSet.collections is deprecated
    # (removed in recent matplotlib releases).
    boundary = Line2D([], [], color="green", linewidth=2)
    plt.legend([boundary, b1, b2],
               [r"p(y|x)=0.5", "positive", "unlabeled"],
               prop={"size": 10})
    plt.tight_layout()
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment