Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
Código para treinar um classificador SVM para o CATMAT
import os
import pickle
from sklearn.utils import shuffle
from sklearn import linear_model
from sklearn import cross_validation
# Load the feature data X from its pickle file.
# NOTE(review): unpickling executes arbitrary code embedded in the file, so
# X.pkl and Y.pkl must come from a trusted source.
with open('X.pkl', 'rb') as x_file:
    X = pickle.load(x_file)
# Load the labels Y from their pickle file.
with open('Y.pkl', 'rb') as y_file:
    Y = pickle.load(y_file)
# Randomise the order of the observations (fixed seed, so reproducible).
X, Y = shuffle(X, Y, random_state=0)
# Split off the training data: 70% for training, the remaining 30% is kept
# aside to be divided again below.
Xtrain, Xvalid, Ytrain, Yvalid = cross_validation.train_test_split(
    X, Y, test_size=0.3, random_state=0)
# Halve the held-out 30% into a validation set and a test set (15% of the
# data each). The test set is not used in the visible part of this script.
Xvalid, Xtest, Yvalid, Ytest = cross_validation.train_test_split(
    Xvalid, Yvalid, test_size=0.5, random_state=0)
# Initialise the classifier: a linear model fitted by stochastic gradient
# descent with an L2 penalty.
# NOTE(review): predict_proba is called on this estimator further down;
# scikit-learn only offers it for some losses ('modified_huber' being one of
# them) -- keep that in mind before changing `loss`.
# NOTE(review): training below goes through partial_fit, which presumably
# performs a single pass per call, so n_iter may have no effect -- confirm
# against the installed scikit-learn version.
sgd_params = {
    'loss': 'modified_huber',
    'penalty': 'l2',
    'alpha': 0.0001,
    'fit_intercept': True,
    'n_iter': 60,
    'shuffle': True,
    'random_state': None,
    'n_jobs': 4,
    'learning_rate': 'optimal',
    'eta0': 0.0,
    'power_t': 0.5,
    'class_weight': None,
    'warm_start': False,
}
clf = linear_model.SGDClassifier(**sgd_params)
# Train the classifier incrementally: partial_fit is fed one contiguous
# mini-batch of the training set at a time.
#
# The batches tile the training set exactly as [start, start + batch_size).
# The previous version sliced [start:start + 710518] while advancing by
# 710519, which silently dropped one row between every pair of batches, and
# it stopped at a hard-coded row count instead of the real size of the
# training set.
batch_size = 710519
n_train = len(Ytrain)
# partial_fit has to be told every label that can ever occur, because a
# single batch is not guaranteed to contain all of them; collect them from
# the full label set Y.
classes = list(set(Y))
for start in range(0, n_train, batch_size):
    # The last batch may be shorter than batch_size.
    end = min(start + batch_size, n_train)
    print(end)  # progress: number of training rows consumed so far
    clf.partial_fit(Xtrain[start:end],
                    Ytrain[start:end],
                    classes=classes)
# Classify: one row of per-class scores for every validation sample.
probs = clf.predict_proba(Xvalid)
# Check the classifications: a sample counts as a match when its true label
# is among the three highest-scoring classes (top-3 accuracy).
matches = 0
total = 0
for scores, truth in zip(probs, Yvalid):
    print(total)  # progress: index of the sample being checked
    total += 1
    # argsort is ascending, so the last three indices are the best three;
    # reversing puts the highest-scoring class first.
    best_three = [clf.classes_[k] for k in scores.argsort()[-3:][::-1]]
    if truth in best_three:
        matches += 1
# Fraction of validation samples whose true label was in the top three.
print(float(matches) / total)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment