Skip to content

Instantly share code, notes, and snippets.

@tsutarou10
Last active September 22, 2017 19:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tsutarou10/93044e5c89c87e0d10fdc46468515d23 to your computer and use it in GitHub Desktop.
Save tsutarou10/93044e5c89c87e0d10fdc46468515d23 to your computer and use it in GitHub Desktop.
交差検定
#coding: utf-8
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import LinearSVC,SVC
import numpy as np
from sklearn.metrics import recall_score,precision_score,f1_score
from sklearn import datasets
class SVM:
    """K-fold cross-validation of a one-vs-rest linear SVM on the iris data."""

    @staticmethod
    def _split_folds(n_samples, n_folds):
        """Return ``n_folds`` disjoint arrays of randomly shuffled indices.

        Together the arrays contain every index in ``range(n_samples)``
        exactly once. ``np.array_split`` is used instead of ``reshape`` so
        that sample counts not divisible by ``n_folds`` still work (the
        first folds simply get one extra element).

        Raises:
            ValueError: if ``n_folds`` is smaller than 2 or larger than
                ``n_samples`` (no usable train/test split exists).
        """
        if n_folds < 2 or n_folds > n_samples:
            raise ValueError(
                "n_folds must be between 2 and n_samples (%d), got %d"
                % (n_samples, n_folds)
            )
        shuffled = np.random.permutation(n_samples)
        return np.array_split(shuffled, n_folds)

    def svm(self, n_folds=5):
        """Run k-fold cross-validation and print the averaged micro scores.

        For each fold the classifier is fitted on the remaining
        ``n_folds - 1`` folds and evaluated on the held-out fold. The mean
        micro-averaged precision, recall and F1 over all folds are printed.

        Args:
            n_folds: number of cross-validation folds (default 5).

        Raises:
            ValueError: if ``n_folds`` is not a usable fold count.
        """
        iris = datasets.load_iris()
        features = iris.data
        target = iris.target
        folds = self._split_folds(features.shape[0], n_folds)

        precisions = []
        recalls = []
        f1_values = []
        for held_out in range(n_folds):
            # Evaluate on the held-out fold and train on all the others
            # (standard k-fold; the earlier code had these two swapped).
            test_idx = folds[held_out]
            train_idx = np.concatenate(
                folds[:held_out] + folds[held_out + 1:]
            )

            clf = OneVsRestClassifier(LinearSVC())
            clf = clf.fit(features[train_idx], target[train_idx])
            pred = clf.predict(features[test_idx])

            expected = target[test_idx]
            precisions.append(precision_score(expected, pred, average="micro"))
            recalls.append(recall_score(expected, pred, average="micro"))
            f1_values.append(f1_score(expected, pred, average="micro"))

        # A parenthesised single argument prints identically on Python 2 and 3.
        print("micro precision : %.2f" % np.mean(precisions))
        print("micro recall : %.2f" % np.mean(recalls))
        print("micro F1 : %.2f" % np.mean(f1_values))
if __name__ == "__main__":
    # Script entry point: build the demo object and run the evaluation once.
    SVM().svm()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment