
@Tcc0403
Created September 20, 2024 18:45
# Parity check: compare torch.nn.KLDivLoss against Liger Kernel's LigerKLDIVLoss
# on both the forward value and the input gradient.
import torch
from torch.nn import KLDivLoss

from liger_kernel.transformers.kl_div import LigerKLDIVLoss

B, T, V = 1, 4096, 32000
dtype, atol, rtol = torch.float32, 1e-8, 1e-6
torch.manual_seed(0)

torch_kldiv = KLDivLoss(reduction="batchmean", log_target=True)
target_kldiv = LigerKLDIVLoss(reduction="batchmean", log_target=True)

# Inputs are log-probabilities; make two independent leaf copies so each loss
# builds its own autograd graph.
input = torch.randn(
    B * T, V, device="cuda", dtype=dtype, requires_grad=True
).log_softmax(dim=-1)
x1 = input.detach().clone().requires_grad_(True)
x2 = input.detach().clone().requires_grad_(True)

# Note: with log_target=True, PyTorch expects the target in log-space; here the
# same probability-space tensor is fed to both losses, so the parity check is
# unaffected either way.
with torch.no_grad():
    target = torch.randn(B * T, V, device="cuda").softmax(dim=-1)

# Forward: both losses should produce the same scalar.
output = torch_kldiv(x1, target)
output2 = target_kldiv(x2, target)
print(output, output2)
assert torch.allclose(output, output2, atol=atol, rtol=rtol)

# Backward: gradients w.r.t. the inputs should also match.
output.backward()
output2.backward()
print(f"{x1.grad=}")
print(f"{x2.grad=}")
assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol)
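
# Manual reference (a sketch, not part of the gist's API under test): with
# log_target=True and reduction="batchmean", PyTorch defines the element-wise loss
# as exp(target) * (target - input), summed over all elements and divided by the
# batch size (B * T rows here). Recomputing that by hand should agree with both
# values printed above, up to float32 reduction-order error.
with torch.no_grad():
    manual = (target.exp() * (target - x1)).sum() / (B * T)
print(f"{manual=}")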