@thomwolf
Last active May 3, 2019 09:08
Compute multi-task loss
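# Forward pass: with both label sets supplied, the model returns the
# language-modeling (LM) loss and the multiple-choice (MC) loss
lm_loss, mc_loss = model(input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids)
# Total loss as a weighted sum; the LM term is weighted twice as heavily as the MC term
lm_coef = 2.0
mc_coef = 1.0
total_loss = lm_loss * lm_coef + mc_loss * mc_coef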
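The combination above is total_loss = lm_coef * lm_loss + mc_coef * mc_loss. Below is a minimal, framework-free sketch of the same weighted sum, generalized to any number of tasks. The helper name `weighted_multitask_loss` and the dictionary-based interface are hypothetical and are not part of any library. The loss values in the example are made up for illustration and do not come from a real training run. The helper only uses `*` and `+`, so it should also accept scalar tensors in place of Python floats.

```python
from typing import Mapping


# Hypothetical helper (not from any library): combines per-task losses into a
# single scalar, mirroring total_loss = lm_loss * lm_coef + mc_loss * mc_coef.
def weighted_multitask_loss(losses: Mapping[str, float],
                            coefs: Mapping[str, float]) -> float:
    """Return the sum of coefs[name] * losses[name] over all tasks.

    Raises KeyError when a task has a loss but no coefficient (or the
    reverse), so a misspelled task name cannot silently drop a term.
    """
    if set(losses) != set(coefs):
        unmatched = sorted(set(losses) ^ set(coefs))
        raise KeyError(f"tasks without a matching loss/coefficient: {unmatched}")
    total = 0.0
    for name, loss in losses.items():
        # Each task contributes its loss scaled by its own weight.
        total = total + coefs[name] * loss
    return total


if __name__ == "__main__":
    # Illustrative values only; the weights match lm_coef and mc_coef above.
    losses = {"lm": 2.5, "mc": 0.75}
    coefs = {"lm": 2.0, "mc": 1.0}
    print(weighted_multitask_loss(losses, coefs))  # 5.75
```

Because the combination is linear, the gradient of the total loss is the same weighted sum of the per-task gradients. Setting lm_coef = 2.0 therefore doubles the LM term's pull on the shared parameters relative to an unweighted sum.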