Skip to content

Instantly share code, notes, and snippets.

@pckujawa
Created December 15, 2013 20:17
Show Gist options
  • Save pckujawa/7977590 to your computer and use it in GitHub Desktop.
Save pckujawa/7977590 to your computer and use it in GitHub Desktop.
Simple audio classifier (speech vs. music) using scikit-learn's Gaussian Naive Bayes classifier. Made for a Multimedia Processing course.
""" usage:
a4.py train TRAIN_FEATURE_FILE [--new] [--validate]
a4.py classify MUSIC_FEATURE_FILE
The files should be CSV with 6 columns, the last of which is the target/label/class (or empty, if classifying), and the first of which is ignored.
"""
#-------------------------------------------------------------------------------
# Name: Pat Kujawa
# Purpose: MM audio classification asn 4
#-------------------------------------------------------------------------------
from __future__ import division
import os, sys
import docopt
import numpy as np
import cPickle as pickle
from sklearn.naive_bayes import GaussianNB
from sklearn import cross_validation
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from sklearn.cross_validation import cross_val_score
# File that train() pickles the classifier to and that train()/classify()
# read it back from. The path is relative, so it is resolved against the
# current working directory.
picklePath = r"classifier.pickle"
# Human-readable class names, looked up with the predicted class value:
# index 0 (False) is speech, index 1 (True) is music.
target_names = ['speech', 'music'] # false, true
def preProcess(csvFile, classifying=False):
    """Read a feature CSV and split it into features, names and labels.

    Column 0 is read as the file name (string), columns 1-4 as the numeric
    features and column 5 as the boolean "is music" label.

    :param csvFile: CSV source passed unchanged to ``np.genfromtxt``
    :param classifying: when true, the label column is not read
    :returns: tuple ``(data, names, targets)``; ``targets`` is ``None``
        when ``classifying`` is true
    """
    def read_columns(columns, **extra):
        # Every part of the file is parsed with the same comma delimiter.
        return np.genfromtxt(csvFile, delimiter=',', usecols=columns, **extra)

    names = read_columns(0, dtype=str)
    data = read_columns((1, 2, 3, 4))
    targets = None if classifying else read_columns(5, dtype=bool)
    return data, names, targets
def _load_classifier(path):
    """Return the classifier pickled at *path*, or None if it cannot be read."""
    try:
        with open(path, 'rb') as f:
            # NOTE: unpickling can execute arbitrary code; only load pickle
            # files that this script wrote itself.
            return pickle.load(f)
    except Exception:
        # Best effort: a missing or corrupt file just means starting fresh.
        sys.stderr.write("Couldn't deserialize classifier. Creating a new one instead \n")
        return None


def _save_classifier(classifier, path):
    """Pickle *classifier* to *path*; report (but do not raise) on failure."""
    try:
        with open(path, 'wb') as f:
            pickle.dump(classifier, f, protocol=pickle.HIGHEST_PROTOCOL)
    except Exception:
        sys.stderr.write("Error persisting classifier to file. Are you in a protected directory\n")


def train(data, names, targets, startNew=False, cv=False):
    """Train a classifier on the given samples and serialize it to picklePath.

    Without ``cv`` the classifier is fitted on all samples. With ``cv`` it is
    fitted on a random 2/3 of the samples, evaluated on the remaining 1/3
    (accuracy, confusion matrix, classification report) and additionally
    scored with leave-one-out cross-validation. In both cases the fitted
    classifier is pickled afterwards so that classify() can use it.

    :param data: feature rows, one per sample
    :param names: file name of each sample (only used for the printed report)
    :param targets: boolean label of each sample (True means music)
    :param startNew: create a new classifier if true, else start from the
        previously pickled classifier when one can be loaded
    :param cv: do cross-validation with a subset of items
    :returns: always the empty string

    NOTE(review): the reloaded classifier is passed to ``fit`` again;
    scikit-learn's ``fit`` presumably retrains from scratch rather than adding
    to earlier training -- confirm before relying on incremental behaviour.
    """
    classifier = None
    if not startNew:
        classifier = _load_classifier(picklePath)
    if classifier is None:
        # From DZone.com refcard: Data Mining - Discovering and Visualizing
        # Patterns with Python by Giuseppe Vettigli
        classifier = GaussianNB()

    if not cv:
        classifier.fit(data, targets)
        print('Trained on all files: ' + ','.join(names))
        # Persist here as well; otherwise a plain training run would leave
        # nothing behind for classify() to load.
        _save_classifier(classifier, picklePath)
        return ''

    # t_ means target, as in expected/desired classification
    train_x, test_x, t_train, t_test, trainFiles, testFiles = \
        cross_validation.train_test_split(data, targets, names, test_size=0.33)

    # show which files are used for train/test
    print('Training files: ' + ','.join(trainFiles))
    print('Test files: ' + ','.join(testFiles))

    classifier.fit(train_x, t_train)
    print('Prior probabilities (n={}):'.format(len(trainFiles)))
    for cls, prob in zip(classifier.classes_, classifier.class_prior_):
        # classes_ holds boolean labels; convert explicitly to a list index
        print('%s %s' % (target_names[int(cls)], prob))

    print("Accuracy for 2/3 training, 1/3 test:")
    print(classifier.score(test_x, t_test))

    # Both metric helpers take the expected labels first and the predicted
    # labels second.
    predictions = classifier.predict(test_x)
    print("Confusion matrix for 2/3 training, 1/3 test:")
    print(confusion_matrix(t_test, predictions))
    print('Classification report for 2/3 training, 1/3 test:')
    print(classification_report(t_test, predictions,
                                target_names=target_names))

    print('leave one out cv')
    # cross validation with leave one out
    # http://stackoverflow.com/questions/17499068/train-scikit-svm-customize-score-assessment
    scores = cross_val_score(classifier, data, targets,
                             cv=cross_validation.LeaveOneOut(len(targets)))
    print('%s %s / %s = %s' % (scores, np.sum(scores), len(scores),
                               np.mean(scores)))

    _save_classifier(classifier, picklePath)
    return ''
def classify(data):
    """Predict the class of one sample with the classifier stored at picklePath.

    :param data: 1-D array holding the feature values of a single sample
    :returns: the matching entry of ``target_names`` ('speech' or 'music'),
        or ``None`` when no pickled classifier could be loaded
    :raises AssertionError: if ``data`` is not one-dimensional
    """
    assert data.ndim == 1, "classify() expects the feature vector of a single sample"
    try:
        with open(picklePath, 'rb') as f:
            # NOTE: unpickling can execute arbitrary code; only load pickle
            # files that this script wrote itself.
            classifier = pickle.load(f)
    except Exception:
        sys.stderr.write("Error: no classifier found. Need to train first.\n")
        return None
    # Present the single feature vector as a one-row (samples x features)
    # array, the shape scikit-learn estimators expect for predict().
    result = classifier.predict(np.atleast_2d(data))
    # The prediction is a boolean label; convert explicitly to a list index.
    return target_names[int(result[0])]
def main():
    """Parse the command line with docopt and run the requested sub-command.

    ``train`` loads the feature file and trains the classifier; ``classify``
    loads the feature file and prints the predicted class. The value returned
    by the sub-command is printed in both cases.
    """
    options = docopt.docopt(__doc__, options_first=False)
    if options['train']:
        features, fileNames, labels = preProcess(options['TRAIN_FEATURE_FILE'])
        outcome = train(features, fileNames, labels,
                        startNew=options['--new'], cv=options['--validate'])
        print(outcome)
    elif options['classify']:
        # Only the feature values (first element) are needed for prediction.
        features = preProcess(options['MUSIC_FEATURE_FILE'])[0]
        print(classify(features))


if __name__ == '__main__':
    main()
sp4.wav,0.106326512992382,0.1419375,600.216506027543,1000,
mu1.wav,0.0135809620842338,0.245192307692308,4090.59341438305,8000,True
mu2.wav,0.00968006905168295,0.0965171330802044,3416.33785202295,7781.25,True
mu3.wav,0.00785261858254671,0.0928920764386943,1335.16906731851,2781.25,True
mu4.wav,0.00693985680118203,0.114189284207566,1938.03139461655,3218.75,True
mu5.wav,0.000804243725724518,0.100921875,2862.56607992312,7593.75,True
mu6.wav,0.0110080037266016,0.0568290129533274,1752.79206161524,3968.75,True
mu7.wav,0.000471108447527513,0.128861388459195,1398.09072062083,4593.75,True
mu8.wav,0.000866658810991794,0.206387362637363,3057.29499579522,5218.75,True
mu9.wav,0.00222460692748427,0.0922671078921079,2056.32778945909,6687.5,True
mu10.wav,0.00612919591367245,0.124267566680729,749.351732128472,750,True
mu11.wav,0.00639878120273352,0.0616570929070929,1104.58139795342,2312.5,True
mu12.wav,0.0103937992826104,0.235546875,2952.09057466559,6875,True
mu13.wav,0.0168704781681299,0.0681885654463351,1508.61129664485,6093.75,True
mu14.wav,0.0366614460945129,0.0399224987890436,609.916803729763,1437.5,True
mu15.wav,0.00200785440392792,0.111,967.358939286494,2125,True
mu16.wav,0.0172653328627348,0.13803125,2137.19411839309,4750,True
mu17.wav,0.0767187625169754,0.09871875,3026.17988619235,7812.5,True
mu18.wav,0.00394412688910961,0.13559375,2234.4520504372,5500,True
mu19.wav,0.00181142438668758,0.16890625,1762.25736668424,4000,True
mu20.wav,0.000672354304697365,0.12184375,2408.25157117664,5687.5,True
sp1.wav,0.000528156640939415,0.15546875,1030.89509309957,1656.25,False
sp2.wav,0.000454288354376331,0.1216875,1226.68622924213,4093.75,False
sp3.wav,0.000502355920616537,0.13159375,1491.06108518104,3062.5,False
sp4.wav,0.106326512992382,0.1419375,600.216506027543,1000,False
sp5.wav,0.000719425734132528,0.1834375,1203.41159032112,1906.25,False
sp6.wav,0.000317843368975446,0.0578484015984016,615.617688734926,1343.75,False
sp7.wav,0.00830729119479656,0.08121875,1311.42313711464,3281.25,False
sp8.wav,0.0164563357830048,0.073848026973027,433.017009278009,937.5,False
sp9.wav,0.150312662124634,0.103109375,534.844251956482,468.75,False
sp10.wav,0.139020338654518,0.143109375,612.676843957936,1062.5,False
sp11.wav,0.000554115045815706,0.135296875,1498.60175536014,2062.5,False
sp12.wav,0.00710451928898692,0.172734375,4233.96991563503,7968.75,False
sp13.wav,0.177054643630981,0.09575,4740.70116489936,8000,False
sp14.wav,0.00660737184807658,0.129464285714286,861.62212786018,4031.25,False
sp15.wav,0.00381537573412061,0.1440625,1136.02801116168,2250,False
sp16.wav,0.035685483366251,0.203984375,2065.26130383625,5375,False
sp17.wav,0.061883706599474,0.20421875,4407.30136680859,7656.25,False
sp18.wav,0.00359855033457279,0.0605176073926074,944.371410860124,3187.5,False
sp19.wav,0.00341002829372883,0.115993381618382,1400.15423227314,1875,False
sp20.wav,0.0158364344388247,0.0796391108891109,625.714886130722,1312.5,False
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment