Created April 18, 2018 14:53
-
-
Save orico/d78c1e8905b1b8ce171d834b61a4c74f to your computer and use it in GitHub Desktop.
AL-BaseSelectionFunction
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class BaseSelectionFunction(object):
    """Common base class for active-learning sample selection strategies.

    The base class holds no state. Subclasses supply their own ``select``.
    The subclasses visible alongside it (e.g. ``RandomSelection``) implement
    it as a static method that takes a matrix of predicted probabilities and
    the number of samples to pick.
    """

    def __init__(self):
        # Nothing to initialise; defer to object for completeness.
        super(BaseSelectionFunction, self).__init__()

    def select(self):
        """Placeholder hook: the base implementation selects nothing.

        Returns
        -------
        None
        """
        return None
class RandomSelection(BaseSelectionFunction):
    """Baseline strategy: pick candidate samples uniformly at random."""

    @staticmethod
    def select(probas_val, initial_labeled_samples, random_state=None):
        """Return ``initial_labeled_samples`` distinct random row indices.

        Only the number of rows of ``probas_val`` is used; the probability
        values themselves are ignored by this strategy.

        Parameters
        ----------
        probas_val : array-like with a ``shape`` attribute
            Candidate pool, one row per sample.
        initial_labeled_samples : int
            How many indices to draw, without replacement.
        random_state : None, int, or RNG object with ``choice``, optional
            ``None`` (default) draws from NumPy's global RNG, matching the
            previous behaviour. An int seeds a fresh
            ``numpy.random.RandomState`` for reproducible draws. An object
            exposing ``choice`` (e.g. a ``RandomState``) is used directly.

        Returns
        -------
        numpy.ndarray
            Distinct indices into the rows of ``probas_val``.

        Raises
        ------
        ValueError
            From NumPy, if ``initial_labeled_samples`` exceeds the number of
            rows (sampling without replacement).
        """
        # A seeded generator used to be constructed here but never consulted,
        # so seeding had no effect. Keep the global RNG as the default and
        # only switch generators when the caller explicitly asks for one.
        if random_state is None:
            rng = np.random
        elif hasattr(random_state, "choice"):
            rng = random_state
        else:
            rng = np.random.RandomState(random_state)
        return rng.choice(probas_val.shape[0], initial_labeled_samples, replace=False)
class EntropySelection(BaseSelectionFunction):
    """Uncertainty sampling: pick the samples whose predictions have maximal entropy."""

    @staticmethod
    def select(probas_val, initial_labeled_samples):
        """Return indices of the ``initial_labeled_samples`` highest-entropy rows.

        Parameters
        ----------
        probas_val : array-like, 2-D
            One row per candidate sample, one column per class (the
            ``sum(axis=1)`` below reduces over columns). Presumably each row
            is a predicted probability distribution — TODO confirm with caller.
        initial_labeled_samples : int
            Number of indices to return.

        Returns
        -------
        numpy.ndarray
            Row indices ordered from highest to lowest entropy (base-2).
        """
        p = np.asarray(probas_val, dtype=float)
        # Use the convention 0 * log2(0) = 0. Evaluating log2 only where
        # p > 0 avoids the 0 * -inf = NaN that would otherwise appear for
        # confident rows; NaN sorts last in argsort, so after the reversal
        # below those confident rows would be picked first.
        log_p = np.zeros_like(p)
        np.log2(p, out=log_p, where=p > 0)
        entropy = -(p * log_p).sum(axis=1)
        # argsort is ascending; reverse it for highest-entropy-first.
        return np.argsort(entropy)[::-1][:initial_labeled_samples]
class MarginSamplingSelection(BaseSelectionFunction):
    """Margin sampling: prefer samples whose two most likely classes are closest."""

    @staticmethod
    def select(probas_val, initial_labeled_samples):
        """Return indices of the ``initial_labeled_samples`` smallest-margin rows.

        The margin of a row is its largest value minus its second-largest
        value, so rows needing at least two columns.

        Parameters
        ----------
        probas_val : numpy.ndarray, 2-D
            One row per candidate sample, one column per class.
        initial_labeled_samples : int
            Number of indices to return.

        Returns
        -------
        numpy.ndarray
            Row indices ordered from smallest to largest margin.
        """
        ascending = np.sort(probas_val, axis=1)
        best = ascending[:, -1]
        runner_up = ascending[:, -2]
        margins = best - runner_up
        ranked = np.argsort(margins)
        return ranked[:initial_labeled_samples]
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment