@AruniRC
Created September 13, 2018 02:43
Pytorch distillation soft targets
# Assumes: import torch.nn.functional as F
if self.distill:
    # Teacher logits are the third element of the data tuple.
    soft_target = data[2].cuda()  # Variable() is deprecated; plain tensors suffice
    # Soft-target cross-entropy at temperature T:
    # mean over the batch of -sum_j p_teacher(j) * log p_student(j)
    distill_loss = torch.mean(torch.sum(
        -F.softmax(soft_target / self.T, dim=1)
        * F.log_softmax(out_data / self.T, dim=1), dim=1))
    loss += self.lbda * distill_loss
    self.writer.add_scalar('train/distill_loss', distill_loss, i_acc + i + 1)
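The loss above can be checked in plain Python without PyTorch. Below is a minimal sketch of the same temperature-scaled soft-target cross-entropy; the function names (`softmax`, `distill_loss`) and the toy logits are illustrative, not part of the original snippet.

```python
import math

def softmax(logits, T=1.0):
    # Temperature-scaled softmax over one row of logits,
    # with the usual max-subtraction for numerical stability.
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    s = sum(exps)
    return [e / s for e in exps]

def distill_loss(student_logits, teacher_logits, T=1.0):
    # Mean over the batch of -sum_j p_teacher(j) * log p_student(j),
    # both distributions softened by the same temperature T.
    total = 0.0
    for s_row, t_row in zip(student_logits, teacher_logits):
        p = softmax(t_row, T)  # teacher soft targets
        q = softmax(s_row, T)  # student probabilities
        total += -sum(pi * math.log(qi) for pi, qi in zip(p, q))
    return total / len(student_logits)

# Toy batch: one sample, two classes.
student = [[2.0, 0.0]]
teacher = [[1.0, 1.0]]
print(distill_loss(student, teacher, T=2.0))
```

One sanity check: when the student matches the teacher exactly, the loss reduces to the entropy of the teacher's softened distribution (e.g. ln 2 for a uniform two-class teacher). Note that Hinton et al.'s formulation additionally multiplies the distillation term by T² to keep gradient magnitudes comparable across temperatures; the gist omits that factor.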
AruniRC commented Sep 13, 2018

(code snippet courtesy Jong-Chyi)
