Knowledge Distillation Loss
from typing import Optional

import torch
from torch import nn
import torch.nn.functional as F


class KDLoss(nn.Module):
    """Knowledge Distillation loss."""

    def __init__(self, dim: int = -1, scale_T: bool = True) -> None:
        """Initializer for KDLoss.

        Args:
            dim (int, optional): Dimension across which to apply loss. Defaults to -1.
            scale_T (bool, optional): Whether to scale the KL-divergence term by T^2. Defaults to True.
        """
        super().__init__()
        self.dim = dim
        self.scale_T = scale_T

    def forward(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        teacher_pred: torch.Tensor,
        T: float,
        alpha: float,
        beta: Optional[float] = None,
    ) -> torch.Tensor:
        """Forward method for KDLoss.

        Args:
            pred (torch.Tensor): Predictions (logits) of the student model. Tensor of shape (batch, num_classes).
            target (torch.Tensor): Labels. LongTensor of shape (batch,), containing class indices like [1, 2, 3, ...].
            teacher_pred (torch.Tensor): Predictions (logits) of the teacher model. Tensor of shape (batch, num_classes).
            T (float): Temperature for softening the softmax distributions.
            alpha (float): Weight for the KL-divergence term.
            beta (float, optional): Weight for the cross-entropy term. If not provided (beta=None), uses beta = 1 - alpha. Defaults to None.

        Returns:
            torch.Tensor: Scalar loss value.
        """
        assert T >= 1.0, f"Expected temperature greater than or equal to 1.0, but got {T}."
        if beta is None:
            assert alpha < 1.0, f"For a weighted average (beta=None), alpha must be less than 1.0, but got {alpha}."
            beta = 1.0 - alpha

        if self.scale_T:
            # Gradients of the soft-target term scale as 1/T^2, so multiplying
            # by T^2 keeps the two terms on comparable scales (Hinton et al., 2015).
            alpha = alpha * T * T

        pred_log_probs = F.log_softmax(pred / T, dim=self.dim)
        teacher_pred_log_probs = F.log_softmax(teacher_pred / T, dim=self.dim)

        # reduction="batchmean" matches the mathematical definition of KL divergence
        # (sum over classes, mean over the batch); the default "mean" would instead
        # average over every element and shrink the term by num_classes.
        kldiv = F.kl_div(pred_log_probs, teacher_pred_log_probs, reduction="batchmean", log_target=True)
        crossentropy = F.cross_entropy(pred, target)

        return alpha * kldiv + beta * crossentropy
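
A minimal usage sketch follows. The model names, tensor shapes, and hyperparameter values (T=4.0, alpha=0.9) are illustrative assumptions, not part of the gist; in practice the teacher logits would come from a frozen pretrained model.

# Usage sketch (illustrative; shapes and hyperparameters are assumptions).
import torch

batch, num_classes = 8, 10
student_logits = torch.randn(batch, num_classes, requires_grad=True)
teacher_logits = torch.randn(batch, num_classes)  # normally computed under torch.no_grad()
labels = torch.randint(0, num_classes, (batch,))

criterion = KDLoss()
# With beta=None, the loss is alpha * T^2 * KL + (1 - alpha) * CE.
loss = criterion(student_logits, labels, teacher_logits, T=4.0, alpha=0.9)
loss.backward()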