@Ab1992ao
Created May 17, 2021 09:18
generate clf data for multitask pipe
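import numpy as np


class MulticlassGenerator:
    """Endlessly yields random classification batches from a tuple of arrays."""

    def __init__(self, data_tuple, batch_size=256):
        # data_tuple: one or more input arrays followed by the label array (last element)
        self._data = data_tuple
        self._idx = np.arange(len(data_tuple[-1]))
        self.generator = self.generate_batch(batch_size)

    def generate_batch(self, size):
        while True:
            # draw `size` distinct row indices; requires size <= number of rows
            px_ids = np.random.choice(self._idx, size, replace=False)
            samples = [p[px_ids] for p in self._data[:-1]]
            labels = self._data[-1][px_ids]
            # labels are appended to the inputs; the second element is a list of ones
            yield samples + [labels], [1] * size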
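
A minimal usage sketch, assuming the data tuple holds NumPy arrays that share the same first dimension, with the label array last. The array names, shapes, and class count below are hypothetical and chosen only for illustration. The class is repeated so the snippet runs on its own.

```python
import numpy as np


class MulticlassGenerator:
    # Same logic as the gist's class, repeated here so the sketch is self-contained.
    def __init__(self, data_tuple, batch_size=256):
        self._data = data_tuple
        self._idx = np.arange(len(data_tuple[-1]))
        self.generator = self.generate_batch(batch_size)

    def generate_batch(self, size):
        while True:
            px_ids = np.random.choice(self._idx, size, replace=False)
            samples = [p[px_ids] for p in self._data[:-1]]
            labels = self._data[-1][px_ids]
            yield samples + [labels], [1] * size


# Hypothetical toy data: two feature arrays and integer labels for 5 classes.
rng = np.random.default_rng(0)
features_a = rng.normal(size=(1000, 8))
features_b = rng.normal(size=(1000, 4))
labels = rng.integers(0, 5, size=1000)

gen = MulticlassGenerator((features_a, features_b, labels), batch_size=32)

# Pull one batch: a list of inputs (features plus labels) and a list of ones.
inputs, targets = next(gen.generator)
print([x.shape for x in inputs])  # [(32, 8), (32, 4), (32,)]
print(len(targets))               # 32
```

Because sampling uses `replace=False`, `batch_size` must not exceed the number of rows, otherwise `np.random.choice` raises a `ValueError`. The list of ones looks like a placeholder target, which would fit a setup where the loss is computed inside the model from the labels passed as an input, but that reading is an assumption.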