@YannDubs
Created March 9, 2019 23:05
Stratified sampling using numpy
import numpy as np


def stratify_sampling(x, n_samples, stratify):
    """Perform stratified sampling of an array or tensor.

    Parameters
    ----------
    x: np.ndarray or torch.Tensor
        Array to sample from. Samples are drawn along the first dimension.

    n_samples: int
        Number of samples to draw (without replacement).

    stratify: tuple of int
        Size of each subgroup. The subgroups are contiguous blocks along the
        first dimension, and the sum of all the sizes needs to be equal
        to `x.shape[0]`.
    """
    n_total = x.shape[0]
    assert sum(stratify) == n_total
    # samples per subgroup, proportional to its size and rounded down
    n_strat_samples = [i * n_samples // n_total for i in stratify]
    # start/end index of each subgroup along the first dimension
    cum_n_samples = np.cumsum([0] + list(stratify))
    sampled_idcs = []
    for i, n_strat_sample in enumerate(n_strat_samples):
        sampled_idcs.append(np.random.choice(np.arange(cum_n_samples[i], cum_n_samples[i + 1]),
                                             replace=False,
                                             size=n_strat_sample))

    # rounding down can leave fewer than `n_samples` samples
    n_current_samples = sum(n_strat_samples)
    if n_current_samples < n_samples:
        delta_n_samples = n_samples - n_current_samples
        # draw the missing samples uniformly among the indices not chosen yet,
        # so that no row is sampled twice
        remaining_idcs = np.setdiff1d(np.arange(n_total), np.concatenate(sampled_idcs))
        sampled_idcs.append(np.random.choice(remaining_idcs, replace=False, size=delta_n_samples))

    samples = x[np.concatenate(sampled_idcs), ...]
    return samples
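Because every subgroup's quota is rounded down, the shortfall is always smaller than the number of subgroups, and `stratify_sampling` tops it up with uniform draws, so the per-subgroup counts are only approximately proportional. If exact proportional counts matter, one option is the largest-remainder (Hamilton) method, swapped in here as an alternative and not the author's approach: give each subgroup its floored quota, then hand the leftover draws, one each, to the subgroups with the largest fractional parts. The sketch below assumes the same contiguous `stratify` layout; `largest_remainder_allocation` and `stratified_indices` are hypothetical helper names, and it uses `np.random.default_rng` (NumPy 1.17 or later).

```python
import numpy as np


def largest_remainder_allocation(stratify, n_samples):
    """Hypothetical helper: split `n_samples` across subgroups in proportion
    to their sizes using the largest-remainder (Hamilton) method."""
    sizes = np.asarray(stratify, dtype=np.int64)
    n_total = sizes.sum()
    exact = sizes * n_samples          # numerators of the exact quotas
    alloc = exact // n_total           # floored quotas, as in stratify_sampling
    remainders = exact % n_total       # proportional to the fractional parts
    n_missing = n_samples - alloc.sum()
    # subgroups with the largest fractional parts get one extra draw each;
    # a stable sort breaks ties by subgroup order
    order = np.argsort(-remainders, kind="stable")
    alloc[order[:n_missing]] += 1
    return alloc


def stratified_indices(stratify, n_samples, rng=None):
    """Hypothetical helper: row indices of a stratified sample drawn without
    replacement, assuming contiguous subgroups of sizes `stratify`."""
    rng = np.random.default_rng() if rng is None else rng
    bounds = np.cumsum([0] + list(stratify))
    alloc = largest_remainder_allocation(stratify, n_samples)
    idcs = [
        rng.choice(np.arange(bounds[i], bounds[i + 1]), size=k, replace=False)
        for i, k in enumerate(alloc)
    ]
    return np.concatenate(idcs)


# quotas are 3.5, 2.1 and 1.4: floors give 3 + 2 + 1 = 6, and the leftover
# draw goes to the first subgroup (largest fractional part, 0.5)
print(largest_remainder_allocation((50, 30, 20), 7))  # → [4 2 1]

x = np.arange(100)
samples = x[stratified_indices((50, 30, 20), 7, np.random.default_rng(0))]
```

A subgroup never receives more draws than it has rows as long as `n_samples <= sum(stratify)`, so `replace=False` is always satisfiable.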