Skip to content

Instantly share code, notes, and snippets.

@mrdrozdov
Last active March 8, 2019 16:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mrdrozdov/547ba246d13787263cb644e2c081fbef to your computer and use it in GitHub Desktop.
Save mrdrozdov/547ba246d13787263cb644e2c081fbef to your computer and use it in GitHub Desktop.
history.py
def override_loss_hook(ner_loss):
    """Temporarily replace ``ner_loss.loss_hook`` with a recording hook.

    The replacement hook appends every ``pred`` it receives to
    ``history['preds']``, ignores ``target`` and returns an empty dict.

    Args:
        ner_loss: Object exposing a ``loss_hook`` attribute. The override is
            installed on this instance only (bound with ``types.MethodType``).

    Returns:
        A ``(history, reset_func)`` tuple. ``history`` is a dict with a single
        key ``'preds'`` holding the list of recorded predictions.
        ``reset_func()`` takes no arguments, undoes the override, and is safe
        to call more than once.
    """
    # Imported locally because the module does not import ``types`` itself.
    import types

    # Remember whether the hook was already an instance-level attribute, so
    # the reset can put the object back into exactly the state it started in.
    had_instance_hook = 'loss_hook' in getattr(ner_loss, '__dict__', {})
    old_loss_hook = ner_loss.loss_hook
    history = dict(preds=[])

    def loss_hook(self, pred, target):
        history['preds'].append(pred)
        return {}

    ner_loss.loss_hook = types.MethodType(loss_hook, ner_loss)

    def reset_func():
        if had_instance_hook:
            ner_loss.loss_hook = old_loss_hook
            return
        # The original hook came from the class: drop the instance override
        # rather than pinning a bound method on the instance, which would keep
        # shadowing the class attribute and create a reference cycle.
        try:
            del ner_loss.loss_hook
        except AttributeError:
            # Override already removed (reset called twice) - nothing to do.
            pass

    return history, reset_func
if __name__ == '__main__':
    # Usage example: record the predictions seen by the loss hook while
    # running over all batches, aggregate them, then restore the hook.
    # NOTE(review): NERLoss, batches and aggregate are not defined in this
    # file - they must be supplied by the surrounding project.
    ner_loss = NERLoss()
    history, reset_loss_hook = override_loss_hook(ner_loss)
    try:
        for batch in batches:
            # Presumably invokes loss_hook internally, which appends the
            # pred to history - confirm against NERLoss.
            ner_loss(batch)
        aggregate(history)
    finally:
        # Restore the original hook even if a batch or the aggregation raises.
        reset_loss_hook()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment