Created
May 28, 2017 13:26
-
-
Save emakryo/8b7383c44871612eb8680fda000c04b8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Import necessary packages | |
import os | |
import numpy as np | |
import pandas as pd | |
from sklearn.svm import SVC | |
from sklearn.model_selection import GridSearchCV | |
from sklearn import metrics | |
def load_data(path='data/real_signal'):
    """Load the train/test data and class files from headerless CSV files.

    Args:
        path: Directory containing ``train_data_svm.txt``,
            ``test_data_svm.txt``, ``train_class_svm.txt`` and
            ``test_class_svm2.txt``. Defaults to ``'data/real_signal'``.

    Returns:
        A tuple ``(Xtr, Xte, ytr, yte)`` of NumPy arrays. ``Xtr`` and ``Xte``
        keep the two-dimensional layout of their files; ``ytr`` and ``yte``
        are flattened to one dimension.
    """
    def read_array(filename):
        # header=None: the first row of each file is data, not column names.
        return np.array(pd.read_csv(os.path.join(path, filename), header=None))

    Xtr = read_array('train_data_svm.txt')
    Xte = read_array('test_data_svm.txt')
    ytr = read_array('train_class_svm.txt')
    # NOTE(review): the test class file is named "..._svm2.txt", unlike the
    # other three; kept as in the original -- confirm this is intentional.
    yte = read_array('test_class_svm2.txt')

    # Class files are read as 2-D tables; flatten them into 1-D label vectors.
    return Xtr, Xte, ytr.reshape(-1), yte.reshape(-1)
def normal_svm():
    """Grid-search an SVC on the training data and report test-set metrics.

    Loads the data via ``load_data()``, runs a 10-fold cross-validated grid
    search over ``C`` and ``gamma``, then evaluates the search object on the
    test set. The selected hyperparameters, accuracy, AUC, precision, recall,
    F1 and the total number of support vectors are printed and written to
    ``result.csv``; the full cross-validation table goes to
    ``cv_results.csv``. Returns nothing.
    """
    X_train, X_test, y_train, y_test = load_data()

    # Candidate values for both hyperparameters: 2**5 through 2**9.
    search_space = {
        'C': [2**exponent for exponent in range(5, 10)],
        'gamma': [2**exponent for exponent in range(5, 10)],
    }

    # n_jobs=-1 is passed through to GridSearchCV unchanged.
    search = GridSearchCV(SVC(), param_grid=search_space, n_jobs=-1, cv=10)
    search.fit(X_train, y_train)

    # Decision values feed the AUC; hard predictions feed the other metrics.
    decision_values = search.decision_function(X_test)
    predicted = search.predict(X_test)

    best = search.best_params_
    summary_row = [
        best['C'],
        best['gamma'],
        metrics.accuracy_score(y_test, predicted),
        metrics.roc_auc_score(y_test, decision_values),
        metrics.precision_score(y_test, predicted),
        metrics.recall_score(y_test, predicted),
        metrics.f1_score(y_test, predicted),
        sum(search.best_estimator_.n_support_),  # support vectors over all classes
    ]

    # One-row summary table, indexed by the model label.
    summary = pd.DataFrame(columns=['C', 'gamma', 'accuracy', 'AUC',
                                    'precision', 'recall', 'F1', 'Nsv'])
    summary.loc['Normal'] = summary_row
    print(summary)
    summary.to_csv('result.csv')

    # Detailed per-candidate cross-validation results.
    pd.DataFrame(search.cv_results_).to_csv('cv_results.csv')
# Run the experiment only when executed as a script, not on import.
if __name__ == "__main__":
    normal_svm()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment