@VictorSanh
Last active August 15, 2023 02:19
Knowledge Distillation
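For logits of shape (batch, classes), the loss in `kd_step` is the KL divergence between the teacher's and the student's temperature-softened softmax distributions, averaged over the batch. The division by the batch size is what `reduction='batchmean'` computes. The symbols $z^t$ (teacher logits), $z^s$ (student logits), $T$ (temperature) and $B$ (batch size) are notation introduced here for illustration:

```latex
p_{b,i} = \frac{\exp(z^t_{b,i}/T)}{\sum_j \exp(z^t_{b,j}/T)}, \qquad
q_{b,i} = \frac{\exp(z^s_{b,i}/T)}{\sum_j \exp(z^s_{b,j}/T)}

\mathcal{L}_{\mathrm{KD}} = \frac{1}{B} \sum_{b=1}^{B} \sum_{i} p_{b,i} \left( \log p_{b,i} - \log q_{b,i} \right)
```

Because the teacher runs under `torch.no_grad()`, gradients flow only through $q$, the student's distribution.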
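```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Optimizer

# KLDivLoss expects log-probabilities as `input` and probabilities as `target`;
# 'batchmean' divides the summed divergence by the batch size.
KD_loss = nn.KLDivLoss(reduction='batchmean')


def kd_step(teacher: nn.Module,
            student: nn.Module,
            temperature: float,
            inputs: torch.Tensor,
            optimizer: Optimizer) -> float:
    """One knowledge-distillation update of `student` toward `teacher`.

    Assumes both models map `inputs` to a logits tensor of shape (batch, classes).
    """
    teacher.eval()   # disable dropout / use running batch-norm stats in the teacher
    student.train()

    # The teacher only provides targets, so skip building its autograd graph.
    with torch.no_grad():
        logits_t = teacher(inputs)
    logits_s = student(inputs)

    # Match the temperature-softened student distribution to the teacher's.
    loss = KD_loss(input=F.log_softmax(logits_s / temperature, dim=-1),
                   target=F.softmax(logits_t / temperature, dim=-1))

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()
```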
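`kd_step` trains on the teacher's soft targets alone. A common variant, described by Hinton et al. (2015), also adds ordinary cross-entropy on ground-truth labels and multiplies the soft term by $T^2$ so its gradient magnitude stays roughly comparable as the temperature changes. The sketch below assumes labeled data is available; the helper name `kd_ce_loss`, the weight `alpha`, and the toy `nn.Linear` models are hypothetical and not part of the original gist.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def kd_ce_loss(logits_s: torch.Tensor,
               logits_t: torch.Tensor,
               labels: torch.Tensor,
               temperature: float,
               alpha: float = 0.5) -> torch.Tensor:
    """Hypothetical helper: weighted sum of soft (KD) and hard (CE) losses."""
    # Soft term: KL(teacher || student) on temperature-softened distributions,
    # same computation as KD_loss in kd_step.
    soft = F.kl_div(F.log_softmax(logits_s / temperature, dim=-1),
                    F.softmax(logits_t / temperature, dim=-1),
                    reduction='batchmean')
    # Scale by T^2 so the soft-term gradients do not shrink as T grows.
    soft = soft * temperature ** 2
    # Hard term: ordinary cross-entropy against the ground-truth labels.
    hard = F.cross_entropy(logits_s, labels)
    return alpha * soft + (1.0 - alpha) * hard


# Toy usage with linear teacher/student models (illustrative only).
torch.manual_seed(0)
teacher = nn.Linear(8, 4)
student = nn.Linear(8, 4)
optimizer = torch.optim.SGD(student.parameters(), lr=0.1)
inputs = torch.randn(16, 8)
labels = torch.randint(0, 4, (16,))

teacher.eval()
student.train()
with torch.no_grad():
    logits_t = teacher(inputs)       # teacher targets, no autograd graph
logits_s = student(inputs)
loss = kd_ce_loss(logits_s, logits_t, labels, temperature=2.0, alpha=0.5)
loss.backward()                      # gradients reach only the student
optimizer.step()
optimizer.zero_grad()
```

With `alpha=1.0` this reduces to the `kd_step` objective scaled by $T^2$; with `alpha=0.0` it is plain supervised training on the labels.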