Created
December 2, 2014 15:32
-
-
Save yamaguchiyuto/03c0ed2acec02bb93da1 to your computer and use it in GitHub Desktop.
LP and LS experiments
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import numpy as np | |
import random | |
from sklearn import datasets | |
from sklearn.semi_supervised import label_propagation | |
from sklearn import svm | |
from sklearn.grid_search import ParameterGrid | |
def _fraction_correct(true_labels, predicted_labels):
    """Return the fraction of positions where the two label arrays agree."""
    return (true_labels == predicted_labels).sum() / float(len(true_labels))


def _holdout_predictions(clf, index_set, X_holdout):
    """Return the labels a fitted estimator assigns to a hold-out set.

    Estimators that expose a fitted ``transduction_`` array are read at the
    positions in ``index_set``; all others are asked to ``predict`` on
    ``X_holdout``.
    """
    if hasattr(clf, 'transduction_'):
        return clf.transduction_[index_set]
    return clf.predict(X_holdout)


def score(estimator, X, y, parameters, validation_true_labels, test_true_labels, validation_set, test_set, X_validation_for_svm=None, X_test_for_svm=None):
    """Get precision on test data after grid-search on validation data.

    "Precision" here is the fraction of hold-out labels predicted exactly.

    Parameters:
        estimator: class (or factory) called as ``estimator(**setting)``.
        X, y: training data passed unchanged to ``fit``.
        parameters: iterable of keyword-argument dicts, one per setting.
        validation_true_labels, test_true_labels: true labels of the
            validation and test sets.
        validation_set, test_set: indices into ``transduction_``; only used
            for estimators that expose that attribute after fitting.
        X_validation_for_svm, X_test_for_svm: feature matrices handed to
            ``predict``; only used for estimators without ``transduction_``.

    Returns:
        Test-set precision of the setting with the best validation precision
        (the first such setting wins ties).

    Raises:
        ValueError: if ``parameters`` yields no setting at all.
    """
    best_estimator = None
    best_precision = -1
    for parameter in parameters:
        clf = estimator(**parameter)
        clf.fit(X, y)
        predicted_labels = _holdout_predictions(clf, validation_set, X_validation_for_svm)
        precision = _fraction_correct(validation_true_labels, predicted_labels)
        # Strict comparison: an equally good later setting does not replace
        # the one found first.
        if precision > best_precision:
            best_estimator = clf
            best_precision = precision
    if best_estimator is None:
        raise ValueError('parameters must contain at least one setting')
    # Evaluate the selected estimator itself, not whichever one was fitted last.
    predicted_labels = _holdout_predictions(best_estimator, test_set, X_test_for_svm)
    return _fraction_correct(test_true_labels, predicted_labels)
# Command line: <number of iterations> <labeled ratio> <shuffle ratio>
niter = int(sys.argv[1])
labeled_ratio = float(sys.argv[2])
shuffle_ratio = float(sys.argv[3])

# Hyper-parameter grids, by list position: 0: LS, 1: LP, 2: SVM
gamma_grid = [0.00001, 0.00005, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 5, 10]
parameters = [{'kernel': ['rbf'], 'gamma': gamma_grid, 'alpha': [0.99]},
              {'kernel': ['rbf'], 'gamma': gamma_grid, 'alpha': [1.0]},
              {'kernel': ['rbf'], 'gamma': gamma_grid, 'C': [0.001, 0.01, 0.1, 1.0, 10.0, 100.0, 1000.0]}]

# Data: the first 1000 samples of the digits dataset
digits = datasets.load_digits()
X = digits.data[:1000]
y = digits.target[:1000]

# Split sizes: the labeled part is taken first, the remainder is halved into
# validation and test (the test set receives the odd sample, if any).
n_total_samples = len(y)
n_labeled_points = int(n_total_samples * labeled_ratio)
# Floor division keeps this an int on Python 3 as well; it is used as a
# slice bound later on.
n_validation_points = (n_total_samples - n_labeled_points) // 2
n_test_points = n_total_samples - n_labeled_points - n_validation_points
indices = np.arange(n_total_samples)
# Number of labeled points whose label is replaced by a random one
n_shuffled_points = int(n_labeled_points * shuffle_ratio)

random.seed(0)  # fixed seed so repeated runs draw the same splits
precisions = [[], [], []]  # per-iteration test precision for LS, LP, SVM
for i in range(0, niter):
    # Draw a fresh labeled / validation / test partition of the sample indices.
    random.shuffle(indices)
    labeled_set = indices[:n_labeled_points]
    unlabeled_set = indices[n_labeled_points:]
    validation_set = indices[n_labeled_points:n_labeled_points + n_validation_points]
    test_set = indices[n_labeled_points + n_validation_points:]

    # Data for SVM: explicit train / validation / test subsets
    X_train_for_svm = X[labeled_set]
    X_validation_for_svm = X[validation_set]
    X_test_for_svm = X[test_set]
    y_train_for_svm = y[labeled_set]
    y_validation_for_svm = y[validation_set]
    y_test_for_svm = y[test_set]

    # Data for SSL: the full label vector with every unlabeled point set to -1
    y_train = np.copy(y)
    y_train[unlabeled_set] = -1
    validation_true_labels = y[validation_set]
    test_true_labels = y[test_set]

    # Shuffle: replace the labels of the first n_shuffled_points labeled
    # points with random digits. y_train is indexed by sample index, so the
    # noise must go through labeled_set; y_train_for_svm is already ordered
    # like labeled_set. Both receive the same noisy labels so the three
    # methods are compared on identical training labels.
    noisy_labels = np.array([random.randint(0, 9) for _ in range(n_shuffled_points)], dtype=y.dtype)
    y_train[labeled_set[:n_shuffled_points]] = noisy_labels
    y_train_for_svm[:n_shuffled_points] = noisy_labels

    # Get scores
    precision_ls = score(label_propagation.LabelSpreading, X, y_train, list(ParameterGrid(parameters[0])), validation_true_labels, test_true_labels, validation_set, test_set)
    precision_lp = score(label_propagation.LabelPropagation, X, y_train, list(ParameterGrid(parameters[1])), validation_true_labels, test_true_labels, validation_set, test_set)
    precision_svm = score(svm.SVC, X_train_for_svm, y_train_for_svm, list(ParameterGrid(parameters[2])), y_validation_for_svm, y_test_for_svm, validation_set, test_set, X_validation_for_svm, X_test_for_svm)
    precisions[0].append(precision_ls)
    precisions[1].append(precision_lp)
    precisions[2].append(precision_svm)
""" Output mean precision and std """ | |
# One "mean std" line per method, in the order the lists were filled
# (LS, LP, SVM). %-formatting prints the same text on Python 2 and Python 3.
for method_precisions in precisions:
    method_precisions = np.array(method_precisions)
    print('%s %s' % (method_precisions.mean(), method_precisions.std()))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
def load_result(filepath):
    """Read a result file made of whitespace-separated "mean std" lines.

    Parameters:
        filepath: path of the text file to read.

    Returns:
        A ``(means, stds)`` tuple of two float lists, one entry per
        non-blank line, in file order.

    Raises:
        ValueError: if a non-blank line does not hold exactly two numbers.
    """
    means = []
    stds = []
    # The context manager closes the file even if a line fails to parse.
    with open(filepath) as result_file:
        for line in result_file:
            fields = line.split()
            if not fields:
                # Tolerate blank lines, e.g. a trailing newline at end of file.
                continue
            mean, std = map(float, fields)
            means.append(mean)
            stds.append(std)
    return (means, stds)
# Accumulators: one value per labeled ratio for each method.
lp_means, ls_means, svm_means = [], [], []
lp_stds, ls_stds, svm_stds = [], [], []

labeled_ratios = np.array([0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5])
# Lines 0, 1 and 2 of every result file feed the LS, LP and SVM series.
collectors = ((ls_means, ls_stds), (lp_means, lp_stds), (svm_means, svm_stds))
for ratio in labeled_ratios:
    means, stds = load_result('result_' + str(ratio) + '_0.0')
    for row, (mean_series, std_series) in enumerate(collectors):
        mean_series.append(means[row])
        std_series.append(stds[row])

# Grouped bar chart: three bars per ratio, each shifted by one bar width.
width = 0.2
loc = np.arange(len(labeled_ratios)) + 1
bar_series = (
    (ls_means, ls_stds, 'r', 'LS'),
    (lp_means, lp_stds, 'g', 'LP'),
    (svm_means, svm_stds, 'b', 'SVM'),
)
for slot, (heights, errors, colour, name) in enumerate(bar_series):
    plt.bar(loc + width * slot, heights, yerr=errors, ecolor='k', color=colour, width=width, label=name, align='center')

plt.xlim([0.5, len(labeled_ratios) + 1])
# Tick labels sit under the middle (LP) bar of each group.
plt.xticks(loc + width, labeled_ratios)
plt.ylabel('Precision')
plt.xlabel('Labeled ratio')
plt.legend(loc='lower right')
plt.show()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import numpy as np | |
import matplotlib.pyplot as plt | |
def load_result(filepath):
    """Read a result file made of whitespace-separated "mean std" lines.

    Parameters:
        filepath: path of the text file to read.

    Returns:
        A ``(means, stds)`` tuple of two float lists, one entry per
        non-blank line, in file order.

    Raises:
        ValueError: if a non-blank line does not hold exactly two numbers.
    """
    means = []
    stds = []
    # The context manager closes the file even if a line fails to parse.
    with open(filepath) as result_file:
        for line in result_file:
            fields = line.split()
            if not fields:
                # Tolerate blank lines, e.g. a trailing newline at end of file.
                continue
            mean, std = map(float, fields)
            means.append(mean)
            stds.append(std)
    return (means, stds)
# Accumulators: one value per shuffled ratio for each method.
lp_means, ls_means, svm_means = [], [], []
lp_stds, ls_stds, svm_stds = [], [], []

# The labeled ratio comes from the command line and selects which family of
# result files is plotted.
labeled_ratio = float(sys.argv[1])
shuffled_ratios = np.array([0.0, 0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5])
# Lines 0, 1 and 2 of every result file feed the LS, LP and SVM series.
collectors = ((ls_means, ls_stds), (lp_means, lp_stds), (svm_means, svm_stds))
for ratio in shuffled_ratios:
    means, stds = load_result('result_' + str(labeled_ratio) + '_' + str(ratio))
    for row, (mean_series, std_series) in enumerate(collectors):
        mean_series.append(means[row])
        std_series.append(stds[row])

# Grouped bar chart: three bars per ratio, each shifted by one bar width.
width = 0.2
loc = np.arange(len(shuffled_ratios)) + 1
bar_series = (
    (ls_means, ls_stds, 'r', 'LS'),
    (lp_means, lp_stds, 'g', 'LP'),
    (svm_means, svm_stds, 'b', 'SVM'),
)
for slot, (heights, errors, colour, name) in enumerate(bar_series):
    plt.bar(loc + width * slot, heights, yerr=errors, ecolor='k', color=colour, width=width, label=name, align='center')

plt.xlim([0.5, len(shuffled_ratios) + 1])
# Tick labels sit under the middle (LP) bar of each group.
plt.xticks(loc + width, shuffled_ratios)
plt.legend(loc='lower right')
plt.ylabel('Precision')
plt.xlabel('Shuffled ratio')
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
以下を実行するとブログ中の図が得られるよ!
[]内は全ての組み合わせを実行してね!