Skip to content

Instantly share code, notes, and snippets.



Created Feb 14, 2017
What would you like to do?
Code for training the SVM classifier
import os
import sklearn
from sklearn import cross_validation, grid_search
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.svm import SVC
from sklearn.externals import joblib
def train_svm_classifer(features, labels, model_output_path):
    """Train an SVM classifier, save the best model and report its performance.

    A grid search with 10-fold cross validation is run over linear and RBF
    kernels on 80% of the data. The remaining 20% is held out and used to
    print a confusion matrix and a classification report.

    Args:
        features: array of input features.
        labels: array of labels associated with the input features.
        model_output_path: path for storing the trained SVM model.

    Returns:
        None. The results are printed to stdout, and the best estimator is
        written to ``model_output_path`` when its parent directory exists.
    """
    # NOTE: ``cross_validation`` and ``grid_search`` are the legacy
    # scikit-learn module names imported at the top of this file; recent
    # scikit-learn releases provide the same helpers in
    # ``sklearn.model_selection``.

    # save 20% of data for performance evaluation
    X_train, X_test, y_train, y_test = cross_validation.train_test_split(
        features, labels, test_size=0.2)

    # One grid per kernel: ``gamma`` is only meaningful for the RBF kernel.
    param = [
        {
            "kernel": ["linear"],
            "C": [1, 10, 100, 1000],
        },
        {
            "kernel": ["rbf"],
            "C": [1, 10, 100, 1000],
            "gamma": [1e-2, 1e-3, 1e-4, 1e-5],
        },
    ]

    # request probability estimation
    svm = SVC(probability=True)

    # 10-fold cross validation, use 4 threads as each fold and each parameter
    # set can be trained in parallel
    clf = grid_search.GridSearchCV(svm, param,
                                   cv=10, n_jobs=4, verbose=3)
    clf.fit(X_train, y_train)

    # The model file itself does not need to exist yet; only the directory it
    # is written into does. An empty dirname means the current directory.
    output_dir = os.path.dirname(model_output_path) or os.curdir
    if os.path.isdir(output_dir):
        joblib.dump(clf.best_estimator_, model_output_path)
    else:
        print("Cannot save trained svm model to {0}.".format(model_output_path))

    print("\nBest parameters set:")
    print(clf.best_params_)

    y_predict = clf.predict(X_test)

    # confusion_matrix expects each class once; ``labels`` holds one entry per
    # sample, so reduce it to the sorted set of distinct classes.
    unique_labels = sorted(set(labels))

    print("\nConfusion matrix:")
    # str() so that non-string (e.g. integer) class labels can be joined too.
    print("Labels: {0}\n".format(",".join(str(label) for label in unique_labels)))
    print(confusion_matrix(y_test, y_predict, labels=unique_labels))

    print("\nClassification report:")
    print(classification_report(y_test, y_predict))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment