Skip to content

Instantly share code, notes, and snippets.

@VictorSanh VictorSanh/kd.py
Last active Jan 16, 2020

Embed
What would you like to do?
Knowledge Distilation
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Optimizer
KD_loss = nn.KLDivLoss(reduction='batchmean')
def kd_step(teacher: nn.Module,
student: nn.Module,
temperature: float,
inputs: torch.tensor,
optimizer: Optimizer):
teacher.eval()
student.train()
with torch.no_grad():
logits_t = teacher(inputs=inputs)
logits_s = student(inputs=inputs)
loss = KD_loss(input=F.log_softmax(logits_s/temperature, dim=-1),
target=F.softmax(logits_t/temperature, dim=-1))
loss.backward()
optimizer.step()
optimizer.zero_grad()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.