Skip to content

Instantly share code, notes, and snippets.

@shi510
Last active January 29, 2023 20:29
Show Gist options
  • Save shi510/b97c044d75b386b9ee7a9e706837c2cf to your computer and use it in GitHub Desktop.
Save shi510/b97c044d75b386b9ee7a9e706837c2cf to your computer and use it in GitHub Desktop.
Tensorflow 2.0 Online Hard Example Mining (OHEM)
class HardExampleMiner(tf.keras.utils.Sequence):
    """Keras Sequence that re-orders the training set every epoch so the
    hardest examples (largest absolute prediction error) come first.

    ``__len__`` exposes only the hardest ``ratio`` fraction of batches,
    so each epoch trains on hard examples only (OHEM).
    """

    def __init__(self, model, x, y, batch_size, map_fn=None, ratio=0.8):
        # NOTE: map_fn is accepted for interface compatibility but unused.
        self.x = np.array(x, dtype=np.float32)
        self.y = np.array(y, dtype=np.float32)
        self.batch_size = batch_size
        self.model = model
        self.ratio = ratio
        # Drop the trailing partial batch.
        self.num_of_batch = len(self.x) // self.batch_size
        n = self.num_of_batch * self.batch_size
        # Indices sorted hardest-first; identity order until first scoring.
        self.hard_idxs = np.arange(n)
        self.errors = np.empty(n)

    def __len__(self):
        # Only the hardest `ratio` fraction of batches is served per epoch.
        return int(self.num_of_batch * self.ratio)

    def __getitem__(self, batch_id):
        start = self.batch_size * batch_id
        end = self.batch_size * (batch_id + 1)
        idxs = self.hard_idxs[start:end]
        # BUG FIX: the original filled and returned a single shared
        # preallocated buffer pair.  With fit(..., workers>0,
        # max_queue_size>0) queued batches alias that buffer and get
        # overwritten before being consumed.  Fancy indexing returns
        # fresh arrays (and no longer assumes 2-D inputs).
        return self.x[idxs], self.y[idxs]

    def on_epoch_end(self):
        """Re-score every example and re-sort indices hardest-first."""
        for batch_id in range(self.num_of_batch):
            sample_x = self._slice_batch(self.x, batch_id)
            sample_y = self._slice_batch(self.y, batch_id)
            # NOTE(review): assumes the model returns a 2-tuple of
            # outputs — confirm against the actual model definition.
            outputs, _ = self.model.predict_on_batch(sample_x)
            diff = np.abs(outputs - sample_y).reshape(-1)
            self.errors[batch_id * self.batch_size:
                        (batch_id + 1) * self.batch_size] = diff
        # argsort of negated errors -> indices in descending error order.
        self.hard_idxs = np.argsort(-self.errors)

    def _slice_batch(self, data, batch_id):
        """Return batch ``batch_id`` of ``data`` reshaped to 2-D."""
        # renamed param: `id` shadowed the builtin in the original
        sample = data[self.batch_size * batch_id:
                      self.batch_size * (batch_id + 1)]
        return sample.reshape(self.batch_size, -1)
class OHEMCallback(tf.keras.callbacks.Callback):
    """Triggers the miner's re-scoring pass at the start of each epoch.

    ``Model.fit()`` does not invoke a ``Sequence``'s ``on_epoch_end()``,
    so this callback calls it manually from ``on_epoch_begin()``.
    """

    def __init__(self, generator):
        super().__init__()
        self.generator = generator

    def on_epoch_begin(self, epoch, logs=None):
        # Re-rank examples before this epoch's batches are drawn.
        self.generator.on_epoch_end()
def make_dataset(x, y, batch):
    """Build a repeating, shuffled, batched tf.data pipeline over (x, y).

    Args:
        x: input examples (array-like, first axis = samples).
        y: targets, same leading dimension as ``x``.
        batch: batch size.

    Returns:
        An infinite ``tf.data.Dataset`` yielding ``(x, y)`` batches.
    """
    ds = tf.data.Dataset.from_tensor_slices((x, y))
    ds = ds.repeat()
    # BUG FIX: `size` was an undefined name (NameError at runtime);
    # shuffle with a buffer covering the whole dataset.
    ds = ds.shuffle(len(x))
    ds = ds.batch(batch)
    ds = ds.prefetch(1024)
    return ds
def train(x, y, valid_x, valid_y, lr=1e-4, batch=512):
    """Train a model with online hard example mining.

    BUG FIX: the original signature placed the non-default parameters
    ``valid_x``/``valid_y`` after defaulted ones, which is a SyntaxError;
    required arguments now come first.  Keyword-argument call sites are
    unaffected.

    Args:
        x, y: training inputs and targets.
        valid_x, valid_y: validation inputs and targets.
        lr: learning rate for the (to-be-built) optimizer.
        batch: batch size for both training and validation.
    """
    # ...
    # build your model here, then:
    # model.compile(...)
    valid_ds = make_dataset(valid_x, valid_y, batch)
    x_gen = HardExampleMiner(model, x, y, batch)
    ohem_callback = OHEMCallback(x_gen)
    model.fit(
        x=x_gen,
        validation_data=valid_ds,
        validation_steps=len(valid_x) // batch,
        epochs=1000,
        callbacks=[ohem_callback],
        max_queue_size=256,
        workers=8,
    )
@shi510
Copy link
Author

shi510 commented Dec 4, 2019

In TF 2.0:
DO NOT USE Model.fit_generator() — it is very SLOW.
Use fit() instead.
However, Model.fit() does not call tf.keras.utils.Sequence's on_epoch_end(),
so you have to implement a custom callback that triggers it from on_epoch_begin().
It works.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment