Skip to content

Instantly share code, notes, and snippets.

@emakryo
Created February 9, 2017 05:41
Show Gist options
  • Save emakryo/a30863edffabb07272ed928e65f41fbb to your computer and use it in GitHub Desktop.
Save emakryo/a30863edffabb07272ed928e65f41fbb to your computer and use it in GitHub Desktop.
import struct
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
def parse(filename):
    """Parse data file created by getdata.

    Every non-blank line must contain at least two whitespace-separated
    fields; the second field is read as a hexadecimal word ``x`` and decoded
    as follows:

    * ``x >> 30``       -- record type (0, 1, 2 or 3, see below)
    * ``(x >> 28) & 3`` -- channel index (0-3)

    Record types (named after the lists they are collected into):

    * 0 -- signal sample.  Bit 27 starts a new waveform (the channel's
      buffer is cleared), the low 14 bits are the sample value, and bit 26
      ends the waveform (the buffer is recorded in ``signal_data``).
    * 1 -- "integ" value, low 26 bits.
    * 2 -- "svm" value: the word shifted left by 4 bits and truncated to
      32 bits is reinterpreted as a native 32-bit float.
    * 3 -- "sort" value, low 25 bits.

    Words whose record type is greater than 3 (i.e. wider than 32 bits) are
    ignored.  Blank lines are skipped.

    Returns:
        Tuple ``(signal_data, integ_data, svm_data, sort_data)``; each item
        is a list with one sub-list per channel.

    Raises:
        IndexError: a non-blank line has fewer than two fields.
        ValueError: the second field is not valid hexadecimal.
    """
    signal = [[], [], [], []]  # per-channel waveform currently being built
    signal_data = [[], [], [], []]
    integ_data = [[], [], [], []]
    svm_data = [[], [], [], []]
    sort_data = [[], [], [], []]

    with open(filename, "r") as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing empty line).
                continue
            x = int(fields[1], 16)
            dtype = x >> 30
            ch = (x >> 28) & 3
            if dtype == 0:
                if (x >> 27) & 1:
                    signal[ch] = []
                signal[ch].append(x & 0x3fff)
                if (x >> 26) & 1:
                    # Store a snapshot: samples that arrive later without a
                    # start flag must not alter a waveform already recorded.
                    signal_data[ch].append(list(signal[ch]))
            elif dtype == 1:
                integ_data[ch].append(x & 0x3ffffff)
            elif dtype == 2:
                # Low 28 bits of the word become the top 28 bits of the float.
                bits = (x << 4) & 0xffffffff
                fl = struct.unpack('f', struct.pack('I', bits))[0]
                svm_data[ch].append(fl)
            elif dtype == 3:
                sort_data[ch].append(x & 0x1ffffff)

    return signal_data, integ_data, svm_data, sort_data
def normalize(X):
    """Standardize the columns of ``X`` to zero mean and unit variance.

    Args:
        X: array-like; statistics are computed along axis 0.

    Returns:
        Tuple ``(X_normalized, mu, sigma)`` where ``mu`` and ``sigma`` are
        the mean and (population) standard deviation along axis 0.  Entries
        of ``sigma`` that would be zero (constant features) are reported as
        1 so that ``(X - mu) / sigma`` stays finite, both here and for any
        consumer that reuses the returned ``mu`` and ``sigma``.
    """
    X = np.asarray(X)
    mu = X.mean(axis=0)
    sigma = X.std(axis=0)
    # A constant feature has zero spread; dividing by it would give NaN/inf.
    sigma = np.where(sigma == 0, np.ones_like(sigma), sigma)
    return (X - mu) / sigma, mu, sigma
def train(pos, neg):
    """Fit an SVC on positive/negative samples, tuning ``C`` by grid search.

    Args:
        pos: 2-D array-like of positive samples (labelled 1).
        neg: 2-D array-like of negative samples (labelled 0); must have the
            same number of columns as ``pos``.

    Returns:
        The ``best_estimator_`` found by ``GridSearchCV``.

    The input shapes are checked with ``assert`` statements.  The number of
    cross-validation folds is 10, or the size of the smaller class when that
    is less than 10.
    """
    pos_arr = np.array(pos)
    neg_arr = np.array(neg)
    assert pos_arr.ndim == 2 and neg_arr.ndim == 2
    assert pos_arr.shape[1] == neg_arr.shape[1]

    n_pos = len(pos_arr)
    n_neg = len(neg_arr)

    c_values = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 50,
                100, 200, 500, 1000, 2000, 5000]
    folds = min(10, n_pos, n_neg)
    search = GridSearchCV(SVC(), param_grid={'C': c_values}, n_jobs=-1,
                          cv=folds)

    # Positives first, then negatives; the labels follow the same order.
    features = np.concatenate([pos_arr, neg_arr])
    labels = np.array([1] * n_pos + [0] * n_neg)

    search.fit(features, labels)
    return search.best_estimator_
def write(estimator, mu, sigma):
    """Dump model parameters and normalization constants to CSV files.

    Each quantity is written as a single comma-separated row to a file in
    the current working directory: ``alpah.csv`` (dual coefficients),
    ``k.csv``, ``b.csv`` (intercept), ``mu.csv`` and ``sigma.csv``.

    Args:
        estimator: fitted object exposing ``dual_coef_``, ``gamma``,
            ``support_vectors_`` and ``intercept_`` (e.g. a fitted
            ``sklearn.svm.SVC``).
        mu: array-like of feature means.
        sigma: array-like of feature standard deviations.

    Raises:
        ValueError: ``estimator.gamma`` is neither a float nor ``'auto'``.
    """
    gamma = estimator.gamma
    if isinstance(gamma, (float, np.floating)):
        K = gamma
    elif isinstance(gamma, str) and gamma == 'auto':
        # NOTE(review): this stores the feature count, whereas the numeric
        # branch stores gamma itself -- confirm which one the consumer of
        # k.csv expects.
        K = estimator.support_vectors_.shape[1]
    else:
        raise ValueError('unknown gamma parameter')

    alpha = np.asarray(estimator.dual_coef_)
    b = np.asarray(estimator.intercept_)

    def _save_row(path, values):
        # Scalars and lists are converted first: only arrays have reshape().
        row = np.asarray(values, dtype=float).reshape(1, -1)
        np.savetxt(path, row, delimiter=',')

    # NOTE(review): 'alpah.csv' looks like a typo for 'alpha.csv'; the name
    # is kept because readers of the file may depend on it.
    _save_row('alpah.csv', alpha)
    _save_row('k.csv', K)
    _save_row('b.csv', b)
    _save_row('mu.csv', mu)
    _save_row('sigma.csv', sigma)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment