Tensorflow 2.0 Online Hard Example Mining (OHEM)
import math

import numpy as np
import tensorflow as tf


class HardExampleMiner(tf.keras.utils.Sequence):
    """Keras Sequence that serves only the hardest training batches.

    After each epoch the training set is re-scored with the current model,
    and the samples are sorted by absolute error; only the hardest `ratio`
    fraction of batches is served in the next epoch.
    """

    def __init__(self, model, x, y, batch_size, map_fn=None, ratio=0.8):
        self.x = np.array(x, dtype=np.float32)
        self.y = np.array(y, dtype=np.float32)
        self.batch_size = batch_size
        self.model = model
        self.ratio = ratio
        self.num_of_batch = int(math.floor(len(self.x) / self.batch_size))
        # Serve samples in their original order until the first re-scoring.
        self.hard_idxs = np.arange(self.num_of_batch * self.batch_size)
        self.errors = np.empty(self.num_of_batch * self.batch_size)
        self.sample_x = np.empty((self.batch_size, self.x.shape[1]))
        self.sample_y = np.empty((self.batch_size, self.y.shape[1]))

    def __len__(self):
        # Only the hardest `ratio` fraction of batches is used per epoch.
        return int(self.num_of_batch * self.ratio)

    def __getitem__(self, batch_id):
        start = self.batch_size * batch_id
        end = self.batch_size * (batch_id + 1)
        # Gather the samples whose indices currently rank as hardest.
        for seq, idx in enumerate(self.hard_idxs[start:end]):
            self.sample_x[seq] = self.x[idx]
            self.sample_y[seq] = self.y[idx]
        return (self.sample_x, self.sample_y)

    def on_epoch_end(self):
        # Re-score the whole training set with the current model and sort
        # the sample indices so that the largest errors come first.
        for batch_id in range(self.num_of_batch):
            sample_x = self._slice_batch(self.x, batch_id)
            sample_y = self._slice_batch(self.y, batch_id)
            # The model here returns two outputs; only the first is used
            # for scoring. Drop the unpacking for a single-output model.
            outputs, _ = self.model.predict_on_batch(sample_x)
            diff = np.abs(outputs - sample_y).reshape(-1)
            start = batch_id * self.batch_size
            self.errors[start:start + self.batch_size] = diff
        self.hard_idxs = np.argsort(-self.errors)

    def _slice_batch(self, x, batch_id):
        sample = x[self.batch_size * batch_id:self.batch_size * (batch_id + 1)]
        return sample.reshape(self.batch_size, -1)
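The core of the mining step is the descending sort in on_epoch_end(): np.argsort(-errors) orders the sample indices from largest to smallest error, so __getitem__ serves the hardest samples first. A tiny illustration with made-up error values:

import numpy as np

errors = np.array([0.1, 0.9, 0.3, 0.7])
hard_idxs = np.argsort(-errors)
print(hard_idxs)  # [1 3 2 0] -> hardest samples come first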
class OHEMCallback(tf.keras.callbacks.Callback):
    """Forwards on_epoch_begin() to the miner's on_epoch_end(), because
    Model.fit() does not call a Sequence's on_epoch_end() itself."""

    def __init__(self, generator):
        super(OHEMCallback, self).__init__()
        self.generator = generator

    def on_epoch_begin(self, epoch, logs=None):
        self.generator.on_epoch_end()


def make_dataset(x, y, batch):
    ds = tf.data.Dataset.from_tensor_slices((x, y))
    ds = ds.repeat()
    ds = ds.shuffle(len(x))
    ds = ds.batch(batch)
    ds = ds.prefetch(1024)
    return ds


def train(x, y, valid_x, valid_y, lr=1e-4, batch=512):
    # ...
    # Build and compile your model here.
    # model.compile(...)
    valid_ds = make_dataset(valid_x, valid_y, batch)
    x_gen = HardExampleMiner(model, x, y, batch)
    ohem_callback = OHEMCallback(x_gen)
    model.fit(
        x=x_gen,
        validation_data=valid_ds,
        validation_steps=int(math.floor(len(valid_x) / batch)),
        epochs=1000,
        callbacks=[ohem_callback],
        max_queue_size=256,
        workers=8)
In TF 2.0, do not use Model.fit_generator(): it is very slow. Use Model.fit() instead. However, Model.fit() does not call tf.keras.utils.Sequence's on_epoch_end(), so you have to implement a custom callback and trigger the re-scoring from on_epoch_begin(). It works.
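As a minimal, self-contained sketch of that workaround in isolation (the toy Sequence, tiny model, and data shapes below are illustrative assumptions, not part of the gist): a counter in on_epoch_end() is bumped by the forwarding callback once per epoch, mirroring how OHEMCallback triggers the miner's re-scoring.

import numpy as np
import tensorflow as tf

class ToySequence(tf.keras.utils.Sequence):
    def __init__(self):
        self.rescore_calls = 0

    def __len__(self):
        return 4

    def __getitem__(self, idx):
        x = np.random.rand(8, 3).astype(np.float32)
        y = np.random.rand(8, 1).astype(np.float32)
        return x, y

    def on_epoch_end(self):
        # In HardExampleMiner this is where the re-scoring happens.
        self.rescore_calls += 1

class ForwardingCallback(tf.keras.callbacks.Callback):
    def __init__(self, seq):
        super().__init__()
        self.seq = seq

    def on_epoch_begin(self, epoch, logs=None):
        self.seq.on_epoch_end()

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
model.compile(optimizer='adam', loss='mse')

seq = ToySequence()
model.fit(seq, epochs=3, callbacks=[ForwardingCallback(seq)], verbose=0)
print(seq.rescore_calls)  # incremented by the callback at each epoch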