This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
trainset_size = 60000 # i.e., testset_size = 10000
def download():
    """Fetch the MNIST data set and return it as ``(X, y)``.

    Prints the shapes of both arrays as a quick sanity check.

    Returns:
        Tuple ``(X, y)`` where ``X`` is ``mnist.data`` cast to ``float64``
        and ``y`` is ``mnist.target``.
    """
    # NOTE(review): fetch_mldata was deprecated in scikit-learn 0.20 and
    # removed in 0.22; fetch_openml('mnist_784') is the documented
    # successor -- confirm the installed version before running.
    mnist = fetch_mldata('MNIST original')
    X = mnist.data.astype('float64')
    y = mnist.target
    print('MNIST:', X.shape, y.shape)
    return (X, y)
def split(train_size): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class BaseModel(object):
    """Base class for the model wrappers; both methods are no-op placeholders."""

    def __init__(self):
        """Do nothing; exists so subclasses share a uniform constructor."""
        pass

    def fit_predict(self):
        """Placeholder that does nothing and returns ``None``."""
        pass
class SvmModel(BaseModel): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class TrainModel:
    """Hold a freshly constructed model and an initially empty accuracy list."""

    def __init__(self, model_object):
        """Instantiate the wrapped model.

        Args:
            model_object: Zero-argument callable (e.g. a model class). It is
                called here and the result is stored as ``self.model_object``.
        """
        self.accuracies = []
        self.model_object = model_object()

    def print_model_type(self):
        """Print the ``model_type`` attribute of the wrapped model."""
        print(self.model_object.model_type)
# We train normally and get probabilities for the validation set, i.e., we use the probabilities to select the most uncertain samples.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class BaseSelectionFunction(object):
    """Base class for sample-selection strategies; both methods are no-ops."""

    def __init__(self):
        """Do nothing; exists so subclasses share a uniform constructor."""
        pass

    def select(self):
        """Placeholder that does nothing and returns ``None``."""
        pass
class RandomSelection(BaseSelectionFunction): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Normalize(object):
    """Min-max scale the train, validation and test feature matrices."""

    def normalize(self, X_train, X_val, X_test):
        """Fit a ``MinMaxScaler`` on the training split and apply it to all three.

        The scaler is fitted on ``X_train`` only and then reused unchanged for
        ``X_val`` and ``X_test``. It is kept on ``self.scaler`` after the call.

        Args:
            X_train: Training features; used to fit the scaler.
            X_val: Validation features; transformed with the fitted scaler.
            X_test: Test features; transformed with the fitted scaler.

        Returns:
            Tuple ``(X_train, X_val, X_test)`` of the transformed inputs.
        """
        self.scaler = MinMaxScaler()
        X_train = self.scaler.fit_transform(X_train)
        X_val = self.scaler.transform(X_val)
        X_test = self.scaler.transform(X_test)
        return (X_train, X_val, X_test)
def inverse(self, X_train, X_val, X_test): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_k_random_samples(initial_labeled_samples, X_train_full, | |
y_train_full): | |
random_state = check_random_state(0) | |
permutation = np.random.choice(trainset_size, | |
initial_labeled_samples, | |
replace=False) | |
print () | |
print ('initial random chosen samples', permutation.shape), | |
# permutation) | |
X_train = X_train_full[permutation] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class TheAlgorithm(object):
    """Bundle the settings of one run: seed-set size, model and selection function."""

    # NOTE(review): class-level list, shared by every instance unless an
    # instance rebinds it -- confirm that sharing is intended.
    accuracies = []

    def __init__(self, initial_labeled_samples, model_object, selection_function):
        """Store the run configuration without doing any work.

        Args:
            initial_labeled_samples: Stored as ``self.initial_labeled_samples``.
            model_object: Stored as-is (not called here) as ``self.model_object``.
            selection_function: Stored as ``self.sample_selection_function``.
        """
        self.initial_labeled_samples = initial_labeled_samples
        self.model_object = model_object
        self.sample_selection_function = selection_function
def run(self, X_train_full, y_train_full, X_test, y_test): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Load the full data set, then carve it into train and test partitions.
(X, y) = download()
# NOTE(review): split() receives only the training size; presumably it reads
# X and y from module scope -- confirm against its definition.
(X_train_full, y_train_full, X_test, y_test) = split(trainset_size)
print('train:', X_train_full.shape, y_train_full.shape)
print('test :', X_test.shape, y_test.shape)
# Number of distinct labels present in y.
classes = len(np.unique(y))
print('unique classes', classes)
def pickle_save(fname, data):
    """Serialize ``data`` with pickle into the file at ``fname``.

    Args:
        fname: Path of the output file; opened in binary write mode, so an
            existing file is overwritten.
        data: Any picklable object.
    """
    # The context manager closes the handle even if pickle.dump raises,
    # so the file is flushed and the descriptor is not leaked.
    with open(fname, "wb") as filehandler:
        pickle.dump(data, filehandler)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def performance_plot(fully_supervised_accuracy, dic, models, selection_functions, Ks, repeats): | |
fig, ax = plt.subplots() | |
ax.plot([0,500],[fully_supervised_accuracy, fully_supervised_accuracy],label = 'algorithm-upper-bound') | |
for model_object in models: | |
for selection_function in selection_functions: | |
for idx, k in enumerate(Ks): | |
x = np.arange(float(Ks[idx]), 500 + float(Ks[idx]), float(Ks[idx])) | |
Sum = np.array(dic[model_object][selection_function][k][0]) | |
for i in range(1, repeats): | |
Sum = Sum + np.array(dic[model_object][selection_function][k][i]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Compare all selection functions for the random-forest model (one repeat).
print('So which is the best sample selection function? margin sampling is the winner!')
performance_plot(random_forest_upper_bound, d, ['RfModel'], selection_functions_str, Ks_str, 1)
print()
# Compare the values of k for margin sampling only (one repeat).
print('So which is the best k? k=10 is the winner')
performance_plot(random_forest_upper_bound, d, ['RfModel'], ['MarginSamplingSelection'], Ks_str, 1)
OlderNewer