Skip to content

Instantly share code, notes, and snippets.

@orico
Created April 18, 2018 14:53
Show Gist options
  • Save orico/d78c1e8905b1b8ce171d834b61a4c74f to your computer and use it in GitHub Desktop.
Save orico/d78c1e8905b1b8ce171d834b61a4c74f to your computer and use it in GitHub Desktop.
AL-BaseSelectionFunction
class BaseSelectionFunction(object):
    """Common base type for the selection strategies defined in this file.

    The class holds no state and implements no logic of its own; concrete
    strategies subclass it and supply their own ``select``.
    """

    def __init__(self):
        pass

    def select(self):
        """Placeholder selection hook; does nothing and returns ``None``.

        NOTE(review): the subclasses in this file define ``select`` as a static
        method taking ``(probas_val, initial_labeled_samples)``. This
        zero-argument signature is kept unchanged for backward compatibility.
        """
        pass
class RandomSelection(BaseSelectionFunction):
    """Selection strategy that picks candidate rows uniformly at random."""

    @staticmethod
    def select(probas_val, initial_labeled_samples):
        """Return distinct, randomly drawn row indices of ``probas_val``.

        Args:
            probas_val: Array whose first dimension (``shape[0]``) is the
                number of candidate rows. Its values are never inspected.
            initial_labeled_samples: Number of indices to draw.

        Returns:
            1-D integer array holding ``initial_labeled_samples`` unique
            indices in ``[0, probas_val.shape[0])``.

        Raises:
            ValueError: If more indices are requested than there are rows,
                because sampling is done without replacement.
        """
        # Sampling goes through NumPy's global RNG, so callers control
        # reproducibility with ``np.random.seed``. No private seeded generator
        # is created here: one that is built but never consulted only
        # misleads readers into thinking the draw is fixed.
        return np.random.choice(
            probas_val.shape[0], initial_labeled_samples, replace=False
        )
class EntropySelection(BaseSelectionFunction):
    """Selection strategy that ranks rows by Shannon entropy (base 2)."""

    @staticmethod
    def select(probas_val, initial_labeled_samples):
        """Return the indices of the rows with the highest entropy.

        Args:
            probas_val: 2-D array-like; entropy is computed along axis 1.
                Presumably each row holds per-class probabilities (confirm
                against the caller).
            initial_labeled_samples: Maximum number of indices to return.

        Returns:
            1-D integer array of row indices ordered from highest to lowest
            entropy, truncated to ``initial_labeled_samples`` entries.
        """
        probas = np.asarray(probas_val)
        # By convention 0 * log2(0) contributes 0 to the entropy. Evaluated
        # naively it is 0 * -inf = NaN, and because argsort places NaN last,
        # the reversal below would rank every row containing a zero as the
        # *most* uncertain. Substituting 1 for zeros makes log2 return 0 there.
        log_arg = np.where(probas == 0, 1, probas)
        entropy = (-probas * np.log2(log_arg)).sum(axis=1)
        # argsort is ascending; reverse it to get highest entropy first.
        ranked = np.argsort(entropy)[::-1]
        return ranked[:initial_labeled_samples]
class MarginSamplingSelection(BaseSelectionFunction):
    """Selection strategy that ranks rows by the gap between their two
    largest values (smallest gap first)."""

    @staticmethod
    def select(probas_val, initial_labeled_samples):
        """Return the indices of the rows with the smallest top-two margin.

        Args:
            probas_val: 2-D array-like with at least two columns; presumably
                per-class probabilities for each row (confirm against the
                caller).
            initial_labeled_samples: Maximum number of indices to return.

        Returns:
            1-D integer array of row indices ordered from smallest to largest
            margin, truncated to ``initial_labeled_samples`` entries.

        Raises:
            IndexError: If ``probas_val`` has fewer than two columns, since
                no second-largest value exists.
        """
        # Sort each row ascending, then flip so column 0 is the row maximum.
        descending = np.sort(probas_val, axis=1)[:, ::-1]
        margins = descending[:, 0] - descending[:, 1]
        # Ascending argsort puts the narrowest margins first.
        return np.argsort(margins)[:initial_labeled_samples]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment