# Adapted from https://pytorch.org/tutorials/beginner/knowledge_distillation_tutorial.html
# Only for illustrating the distillation training loop as a code fragment.
# Assumes the teacher and student models, train_loader, optimizer, device, and the
# hyperparameters epochs, T, soft_target_loss_weight, and ce_loss_weight are defined elsewhere.
import torch
import torch.nn as nn

ce_loss = nn.CrossEntropyLoss()

for epoch in range(epochs):
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        # Forward pass with the teacher model.
        # Do NOT save gradients here, as we do not change the teacher's weights.
        with torch.no_grad():
            teacher_logits = teacher(inputs)

        # Forward pass with the student model
        student_logits = student(inputs)

        # Soften the teacher's logits with softmax and take the log-softmax
        # of the student's logits, both at temperature T.
        soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
        soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

        # Calculate the soft-targets (distillation) loss.
        # It is scaled by T**2, as suggested by the authors of the paper
        # "Distilling the Knowledge in a Neural Network".
        soft_targets_loss = -torch.sum(soft_targets * soft_prob) / soft_prob.size()[0] * (T**2)

        # Calculate the true-label loss with nn.CrossEntropyLoss()
        label_loss = ce_loss(student_logits, labels)

        # Weighted sum of the two losses
        loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")