@nokados
Last active March 8, 2019 14:57
An analogue of sklearn's train_test_split for multilabel classification, with stratification and shuffling, plus under-/over-sampling to flatten the distribution of class sizes.
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

def parallel_shuffle(*arrays):
    length = arrays[0].shape[0]
    for arr in arrays:
        assert arr.shape[0] == length
    p = np.random.permutation(length)
    return [arr[p] for arr in arrays]
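
# Hedged usage sketch (not in the original gist): apply the same permutation to
# features and labels. The names X and y below are assumptions for illustration.
# X_shuf, y_shuf = parallel_shuffle(X, y)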

def multi_strat_split(x_train, y_train, test_size=0.2, random_state=None):
    """Stratified multilabel train/test split; y_train is a binary indicator matrix."""
    # Sizes: test_size is a fraction on input, a sample count afterwards
    test_freq = test_size
    size = x_train.shape[0]
    train_size = int((1 - test_freq) * size)
    test_size = size - train_size

    # Shuffle before splitting
    y = np.array(y_train)
    np.random.seed(random_state)
    x, y = parallel_shuffle(x_train, y)

    # Create resulting arrays
    X_train_new = np.zeros((train_size, *x.shape[1:]))
    Y_train_new = np.zeros((train_size, *y.shape[1:]))
    X_test_new = np.zeros((test_size, *x.shape[1:]))
    Y_test_new = np.zeros((test_size, *y.shape[1:]))

    # Order classes by size, rarest first, so small classes get their test quota
    class_sizes = y.sum(axis=0)
    class_indices = np.argsort(class_sizes)

    # Choose samples
    test_index = 0
    train_index = 0
    used_indices = set()
    for cls_id in class_indices:
        cls_size = class_sizes[cls_id]
        cls_train_size = int((1 - test_freq) * cls_size)
        cls_test_size = cls_size - cls_train_size
        current_test_size = Y_test_new[:, cls_id].sum()
        diff = cls_test_size - current_test_size
        cls_samples_indices = np.argwhere(y[:, cls_id] == 1)

        # Iterate to add test samples
        for ind in cls_samples_indices:
            ind = ind[0]
            if diff <= 0:
                break
            if test_index >= test_size:
                break
            if ind in used_indices:
                continue
            X_test_new[test_index] = x[ind]
            Y_test_new[test_index] = y[ind]
            test_index += 1
            used_indices.add(ind)
            diff -= 1

        # Iterate to add train samples
        for ind in cls_samples_indices:
            ind = ind[0]
            if train_index >= train_size:
                break
            if ind in used_indices:
                continue
            X_train_new[train_index] = x[ind]
            Y_train_new[train_index] = y[ind]
            train_index += 1
            used_indices.add(ind)

    assert train_index == train_size

    # Top up the test set with any leftover samples
    if test_index < test_size:
        unused_indices = set(range(x.shape[0])) - used_indices
        for ind in unused_indices:
            X_test_new[test_index] = x[ind]
            Y_test_new[test_index] = y[ind]
            test_index += 1
            used_indices.add(ind)
    assert test_index == test_size

    # Report how far each class deviates from the requested test fraction
    test_parts = Y_test_new.sum(axis=0) / class_sizes
    print('Min test_part: ', test_parts.min(), ' at index ', test_parts.argmin())
    print('Max test_part: ', test_parts.max(), ' at index ', test_parts.argmax())

    X_train_new, Y_train_new = parallel_shuffle(X_train_new, Y_train_new)
    X_test_new, Y_test_new = parallel_shuffle(X_test_new, Y_test_new)
    return X_train_new, X_test_new, Y_train_new, Y_test_new
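
# --- Hedged usage sketch (not part of the original gist) ---------------------
# Minimal illustration of multi_strat_split on a toy multilabel dataset; the
# shapes, seed, and variable names below are assumptions, not part of the gist.
# np.random.seed(0)
# X_demo = np.random.rand(120, 5)                        # 120 samples, 5 features
# Y_demo = np.eye(4)[np.random.randint(0, 4, size=120)]  # 4 classes, one-hot rows
# Y_demo[::10, 0] = 1                                    # make some rows multilabel
# X_tr, X_te, Y_tr, Y_te = multi_strat_split(X_demo, Y_demo, test_size=0.2, random_state=42)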

def flat_sampling(x_train, y_train, max_quantile=0.85, max_div_min=5, seed=None):
    """Under-/over-sample a multilabel dataset so class sizes are flatter."""
    np.random.seed(seed)
    x_train = np.array(x_train)
    y_train = np.array(y_train)
    class_sizes = y_train.sum(axis=0)
    print(f'BEFORE: Max size {class_sizes.max()}. Min size: {class_sizes.min()}. Total samples: {x_train.shape[0]}')
    class_indices = np.argsort(class_sizes)
    plt.bar(range(len(class_sizes)), class_sizes[class_indices])

    X = np.zeros((0, *x_train.shape[1:]))
    y = np.zeros((0, *y_train.shape[1:]))

    def updateXy(indices):
        nonlocal X
        nonlocal y
        if len(indices) == 0:
            return
        X = np.concatenate((X, x_train[indices]), axis=0)
        y = np.concatenate((y, y_train[indices]), axis=0)

    # Target sizes: cap at the class size at the max_quantile position of the sorted
    # sizes, floor at max_size / max_div_min
    max_size = int(class_sizes[class_indices[int(len(class_sizes) * max_quantile)]])
    min_size = max(1, max_size // max_div_min)
    print(f'Expected AFTER: Max size {max_size}. Min size: {min_size}')

    used_indices = set()
    for cls_id in class_indices:
        cls_samples_indices = np.argwhere(y_train[:, cls_id] == 1)[:, 0]
        actual_size = int(y[:, cls_id].sum())
        unused_indices = np.array([ind for ind in cls_samples_indices if ind not in used_indices])
        if class_sizes[cls_id] < min_size:
            # Oversample rare classes: keep every sample, then draw extras with replacement
            updateXy(cls_samples_indices)
            add_num = max(0, int(min_size - class_sizes[cls_id] - actual_size))
            additional_indices = np.random.choice(cls_samples_indices, add_num)
            updateXy(additional_indices)
        elif class_sizes[cls_id] > max_size or len(unused_indices) + actual_size > max_size:
            # Undersample frequent classes: draw without replacement up to max_size
            if max_size <= actual_size:
                continue
            indices = np.random.choice(unused_indices, max_size - actual_size, replace=False)
            updateXy(indices)
        else:
            updateXy(unused_indices)
        used_indices |= set(cls_samples_indices)

    assert X.shape[0] == y.shape[0]
    class_sizes = y.sum(axis=0)
    print(f'Actual AFTER: Max size {class_sizes.max()}. Min size: {class_sizes.min()}. Total samples: {X.shape[0]}')
    plt.bar(range(len(class_sizes)), class_sizes[class_indices], alpha=0.5)
    plt.show()
    return parallel_shuffle(X, y)
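
# --- Hedged usage sketch (not part of the original gist) ---------------------
# One possible workflow, assuming the X_demo / Y_demo arrays from the sketch above:
# split first, then flatten only the training portion so duplicated (oversampled)
# rows cannot leak into the test set.
# X_tr, X_te, Y_tr, Y_te = multi_strat_split(X_demo, Y_demo, test_size=0.2, random_state=42)
# X_tr_flat, Y_tr_flat = flat_sampling(X_tr, Y_tr, max_quantile=0.85, max_div_min=5, seed=0)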