Skip to content

Instantly share code, notes, and snippets.

@satra
Created April 1, 2010 12:25
Show Gist options
  • Save satra/351740 to your computer and use it in GitHub Desktop.
Save satra/351740 to your computer and use it in GitHub Desktop.
import numpy as np
import matplotlib.pyplot as plt
import mvpa.suite as ms
# Build a 2-D XOR-style toy problem: 100 points drawn uniformly from
# [-0.5, 0.5) x [-0.5, 0.5).  A point is labelled True when both of its
# coordinates have the same sign and False otherwise, so the two classes
# occupy diagonally opposite quadrants and are not linearly separable.
samples = np.random.rand(100, 2) - 0.5
signs = np.sign(samples)
targets = signs[:, 0] == signs[:, 1]

# Scatter the two classes: True -> red squares, False -> blue squares.
plt.plot(samples[targets, 0], samples[targets, 1], 'rs')
plt.plot(samples[~targets, 0], samples[~targets, 1], 'bs')

# Wrap the data as a PyMVPA dataset, giving every sample its own chunk so
# that splitters can partition the data at single-sample granularity.
ds = ms.dataset_wizard(samples,
                       targets=targets,
                       chunks=range(samples.shape[0]))
# Choose a classifier.  kNN is used here; other PyMVPA classifiers
# (e.g. SMLR, LinearCSVMC) can be substituted.
clf = ms.kNN()
terr = ms.TransferError(clf)

# Cross-validate the same transfer error under two splitting schemes:
#   * HalfSplitter  - presumably splits the chunks into two halves;
#   * NFoldSplitter - leaves one chunk out per fold, i.e. leave-one-out
#                     here because `ds` gives every sample its own chunk.
# Each splitter is constructed just before it is used (same order as two
# sequential sections), and both runs are reported and plotted identically.
for make_splitter in (ms.HalfSplitter, ms.NFoldSplitter):
    cvte = ms.CrossValidatedTransferError(terr, splitter=make_splitter(),
                                          enable_ca=['confusion'])
    res = cvte(ds)
    # A bare `np.mean(res)` expression is discarded when run as a script,
    # so print the mean cross-validation error explicitly.
    print('mean cv error: %s' % np.mean(res))
    print(cvte.ca.confusion)
    plt.figure()
    cvte.ca.confusion.plot()
    plt.show()
# Model selection: cross-validate every binary SVM registered in the
# classifier warehouse (`ms.clfswh`), keep the one with the lowest
# confusion error, then re-evaluate and plot the winner.


def _balanced_half_cv(classifier):
    """Return a confusion-tracking cross-validated transfer error.

    Uses HalfSplitter(npertarget='equal', nrunspersplit=10) -- presumably
    balanced samples per target, repeated over 10 runs per split (see the
    PyMVPA splitter docs to confirm).
    """
    return ms.CrossValidatedTransferError(
        ms.TransferError(classifier),
        splitter=ms.HalfSplitter(npertarget='equal', nrunspersplit=10),
        enable_ca=['confusion'])


# Best (lowest) confusion error seen so far and a copy of its classifier.
# These must be initialised before the loop: previously the first
# `< error` comparison raised NameError, which the bare `except:` hid for
# every classifier, leaving `bestclf` undefined for the final evaluation.
error = np.inf
bestclf = None
for clf in ms.clfswh['binary', 'svm']:
    cvte = _balanced_half_cv(clf)
    try:
        res = cvte(ds)
    except Exception as exc:
        # Best effort: a warehouse entry may not run on this data/setup;
        # report why and continue with the remaining classifiers.
        print('could not run classifier: %s (%s)' % (clf, exc))
        continue
    clf_error = cvte.ca.confusion.error
    print('%s %s' % (clf, clf_error))
    print('mean cv error: %s' % np.mean(res))
    print(cvte.ca.confusion.matrix)
    if clf_error < error:
        error = clf_error
        bestclf = clf.clone()

if bestclf is None:
    raise RuntimeError('no binary SVM classifier could be cross-validated')

# Re-run the cross-validation for the winning classifier and plot its
# confusion matrix with the cell counts annotated.
cvte = _balanced_half_cv(bestclf)
res = cvte(ds)
print('mean cv error: %s' % np.mean(res))
print(cvte.ca.confusion)
plt.figure()
cvte.ca.confusion.plot(numbers=True)
# suptitle expects text; the original passed the classifier object itself.
plt.suptitle(str(cvte.transerror.clf))
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment