# Lifted verbatim from https://pytorch.org/tutorials/beginner/knowledge_distillation_tutorial.html
# Only for illustrating the distillation training loop as a code fragment.
# Assumes teacher, student, train_loader, optimizer, device, epochs, the temperature T,
# the loss weights soft_target_loss_weight and ce_loss_weight, and
# ce_loss = nn.CrossEntropyLoss() are defined beforehand (a hypothetical setup
# sketch follows the loop).
import torch
from torch import nn

for epoch in range(epochs):
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        # Forward pass with the teacher model.
        # Do NOT track gradients here as we do not change the teacher's weights.
        with torch.no_grad():
            teacher_logits = teacher(inputs)

        # Forward pass with the student model.
        student_logits = student(inputs)

        # Soften both distributions with the temperature T: softmax for the
        # teacher's soft targets, log_softmax for the student's predictions.
        soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
        soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

        # Soft-target loss: cross-entropy between the teacher's softened distribution
        # and the student's log-probabilities (same gradients w.r.t. the student as a
        # KL divergence). It is scaled by T**2 as suggested by the authors of the paper
        # "Distilling the Knowledge in a Neural Network".
        soft_targets_loss = -torch.sum(soft_targets * soft_prob) / soft_prob.size()[0] * (T**2)

        # True-label loss against the hard labels, using nn.CrossEntropyLoss().
        label_loss = ce_loss(student_logits, labels)

        # Weighted sum of the two losses.
        loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")
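
# --------------------------------------------------------------------------
# Hypothetical setup sketch (not part of the tutorial fragment above). It
# illustrates one way the names the loop assumes -- teacher, student,
# train_loader, optimizer, ce_loss, device, epochs, T, and the two loss
# weights -- might be defined. The model architectures, dummy data, and
# hyperparameter values are illustrative placeholders; to run the loop
# end-to-end, this block would go before it.
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Placeholder classifiers: a larger teacher and a smaller student.
teacher = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 512), nn.ReLU(), nn.Linear(512, 10)).to(device)
student = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 10)).to(device)

teacher.eval()   # the teacher's weights stay frozen during distillation
student.train()

# Dummy data standing in for a real dataset.
train_loader = DataLoader(
    TensorDataset(torch.randn(256, 1, 28, 28), torch.randint(0, 10, (256,))),
    batch_size=32, shuffle=True)

epochs = 5
T = 2                           # distillation temperature (illustrative value)
soft_target_loss_weight = 0.25  # weight on the softened-teacher loss (illustrative)
ce_loss_weight = 0.75           # weight on the hard-label loss (illustrative)

ce_loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(student.parameters(), lr=1e-3)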