Last active
October 29, 2019 09:19
-
-
Save racinmat/a8eb2f727fcc4bd24745042eee5cd8d6 to your computer and use it in GitHub Desktop.
Example of multiple instance learning problem with fit_generator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
import tensorflow as tf | |
import numpy as np | |
from sklearn.preprocessing import MultiLabelBinarizer | |
def highest_idx_leq(arr_cumsum, thr):
    """Return the index of the last element of ``arr_cumsum`` that is <= ``thr``.

    ``arr_cumsum`` must be sorted in non-decreasing order (as required by
    ``np.searchsorted``), e.g. a cumulative sum of non-negative lengths.
    Returns -1 when every element is greater than ``thr``.
    """
    # side='right' places thr after any equal elements, so stepping back one
    # position lands on the last element that is still <= thr.
    insertion_point = np.searchsorted(arr_cumsum, thr, side='right')
    return insertion_point - 1
# Essentially the same logic as batch_creation_train, but packaged as a Keras Sequence generator.
class BatchGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` producing shuffled multiple-instance-learning batches.

    Each row of the data frame is a bag: ``'list'`` holds the bag's instances,
    ``'list_len'`` its instance count and ``'label'`` its class label. At every
    epoch end the bags are shuffled and then packed greedily, in order, into
    batches whose total instance count does not exceed the batch size. A single
    bag that is larger than the batch size forms a batch on its own.

    Args:
        data: DataFrame with ``'list'``, ``'list_len'`` and ``'label'`` columns.
            Defaults to the module-level ``df``.
        batch_size: Maximum number of instances per batch (oversized bags
            excepted). Defaults to the module-level ``batch_size``.
        encoder: Object whose ``transform`` maps a bag's instance list to a
            feature matrix. Defaults to the module-level ``list_enc``.

    Each item is ``((feats, segments), labels)``:
        ``feats`` is the row-wise concatenation of the encoded instances of all
        bags in the batch.
        ``segments`` gives each instance row the batch-local index (0..n-1) of
        its bag.
        ``labels`` holds one int32 label per bag.
    """

    def __init__(self, data=None, batch_size=None, encoder=None):
        super().__init__()
        # None means "use the module-level global". The global is looked up at
        # use time, so the default behaviour matches the original global-based
        # implementation.
        self._bag_frame = data
        self._max_batch_instances = batch_size
        self._bag_encoder = encoder
        self.indices = self._get_frame().index
        self.cur_index = 0  # kept for compatibility; not read by this class
        self.i_to_batch_end = dict()
        self.batches_num = 0
        self.on_epoch_end()

    def _get_frame(self):
        """Return the bag DataFrame (explicit argument or global ``df``)."""
        return df if self._bag_frame is None else self._bag_frame

    def _get_batch_size(self):
        """Return the instance budget per batch (explicit or global ``batch_size``)."""
        return batch_size if self._max_batch_instances is None else self._max_batch_instances

    def _get_encoder(self):
        """Return the instance encoder (explicit or global ``list_enc``)."""
        return list_enc if self._bag_encoder is None else self._bag_encoder

    def batches_num_true(self):
        """Split the current ``self.indices`` order into batches.

        Rebuilds ``self.i_to_batch_end`` so that batch ``i`` covers positions
        ``i_to_batch_end[i - 1] + 1`` through ``i_to_batch_end[i]`` (inclusive)
        of ``self.indices``; the sentinel ``-1 -> -1`` makes batch 0 start at 0.

        Returns:
            The number of batches.
        """
        list_cumsum = (self._get_frame().loc[self.indices]['list_len']
                       .cumsum().reset_index(drop=True).to_numpy())
        limit = self._get_batch_size()
        # Start from a fresh map so entries from a previous epoch with more
        # batches cannot linger.
        ends = {-1: -1}
        if len(list_cumsum) == 0:
            self.i_to_batch_end = ends
            return 0
        list_cumsum_max = list_cumsum.max()
        index_start, batches_num, prev_end = 0, 0, -1
        while index_start < list_cumsum_max:
            batch_end = highest_idx_leq(list_cumsum, index_start + limit)
            # If the next bag alone exceeds `limit`, highest_idx_leq returns the
            # previous end. Without this clamp the loop never advances, or the
            # -1 result wraps around to the last element and drops every bag.
            batch_end = max(batch_end, prev_end + 1)
            index_start = list_cumsum[batch_end]
            ends[batches_num] = batch_end
            prev_end = batch_end
            batches_num += 1
        self.i_to_batch_end = ends
        return batches_num

    def __len__(self):
        """Number of batches in the current epoch."""
        return self.batches_num

    def __getitem__(self, index):
        """Build batch ``index`` as ``((feats, segments), labels)``."""
        start = self.i_to_batch_end[index - 1] + 1
        stop = self.i_to_batch_end[index] + 1
        batch = self._get_frame().loc[self.indices[start:stop]].copy()
        # Batch-local bag ids; repeated once per instance they become sorted
        # segment ids, as required by tf.math.segment_mean.
        batch['id'] = np.arange(0, len(batch))
        feats = np.concatenate(batch['list'].apply(self._get_encoder().transform).values)
        # int32 to match the dtype=tf.int32 `segments` Input of the model.
        segments = batch['id'].repeat(batch['list_len']).values.astype(np.int32)
        labels = batch['label'].values.astype(np.int32)
        return (feats, segments), labels

    def on_epoch_end(self):
        """Reshuffle the bags and recompute the batch boundaries."""
        self.indices = np.random.permutation(self._get_frame().index)
        self.batches_num = self.batches_num_true()
        self.cur_index = 0
class SegmentedMean(tf.keras.layers.Layer):
    """Keras layer that averages feature rows sharing a segment id.

    Called on a pair ``(features, segments)``, it returns
    ``tf.math.segment_mean(features, segments)``. TensorFlow expects the
    segment ids to be sorted.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def call(self, inputs, **kwargs):
        data, segment_ids = inputs
        pooled = tf.math.segment_mean(data, segment_ids)
        return pooled
# Toy multiple-instance dataset: each row is a bag whose 'list' holds its
# instances (each instance is a list of ints), paired with a 0/1 'label'.
df = pd.DataFrame(dict(
    list=[
        [[1, 2, 3, 4], [2, 3, 6]],
        [[1, 2], [1, 5, 8]],
        [[6, 7, 8], [2, 4, 10], [1, 6], [5]],
        [[3, 4, 6, 8], [1, 8], [2], [7]],
        [[3, 6, 8]],
        [[4, 2, 2], [8, 1, 5, 6]],
        [[3, 2, 1], [9, 8], [4]],
    ],
    label=[0, 0, 1, 1, 1, 0, 0],
))
# Number of instances in each bag.
df['list_len'] = df['list'].map(len)
# Instance budget per batch used when packing bags into batches.
batch_size = 8
# Sorted list of every distinct integer appearing in any instance of any bag.
# Sorting makes the order deterministic; a plain list(set(...)) is not.
nums_used = sorted({num for bag in df['list'] for instance in bag for num in instance})
# Fit one binarizer class per distinct number, so transforming a bag's list of
# instances yields one multi-hot row per instance.
list_enc = MultiLabelBinarizer().fit([[num] for num in nums_used])
class MyCallback(tf.keras.callbacks.Callback):
    """Debug callback printing the ``logs`` Keras passes at each training-batch hook."""

    def on_train_batch_begin(self, batch, logs=None):
        super().on_train_batch_begin(batch, logs)
        # Name the hook that actually fired so begin/end output can be told apart.
        print(f'on_train_batch_begin called, logs: {logs}')

    def on_train_batch_end(self, batch, logs=None):
        super().on_train_batch_end(batch, logs)
        print(f'on_train_batch_end called, logs: {logs}')
# Build the bag generator and a small MIL network:
# instance-level Dense layers -> per-bag mean pooling -> bag-level classifier.
generator = BatchGenerator()
settings = {'k': 40, 'steps': 10}  # 'k' is the hidden width; 'steps' is not used below
feats_len = len(nums_used)  # width of each encoded instance row
inputs = tf.keras.Input(shape=(feats_len,), name='features')
# One bag id per instance row; SegmentedMean averages rows sharing an id.
segments = tf.keras.Input(shape=(), name='segments', dtype=tf.int32)
x = tf.keras.layers.Dense(settings['k'], activation=tf.nn.relu)(inputs)
x = tf.keras.layers.Dense(settings['k'])(x)
x = SegmentedMean()((x, segments))
x = tf.keras.layers.Dense(settings['k'], activation=tf.nn.relu)(x)
logits = tf.keras.layers.Dense(2, name='output_logits')(x)
probs = tf.keras.layers.Softmax()(logits)
model = tf.keras.Model(inputs=(inputs, segments), outputs=(logits, probs), name='mil_model')
optimizer = tf.keras.optimizers.Adam()
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# Only the logits output has a loss; the softmax output is not trained directly.
model.compile(optimizer=optimizer, loss={'output_logits': loss})
# Model.fit_generator is deprecated: it merely forwards to fit in TF >= 2.1
# and is gone from newer Keras. fit accepts a keras Sequence directly.
model.fit(generator, epochs=5, workers=4, callbacks=[MyCallback()])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment