Softmax differentiation
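The snippet below differentiates the softmax by hand and checks the result against PyTorch's autograd. For context (a standard result, included here as a reminder rather than taken from the gist): for $s = \mathrm{softmax}(z)$, the Jacobian is

$$\frac{\partial s_i}{\partial z_j} = s_i(\delta_{ij} - s_j),$$

which gives $s_i(1 - s_i)$ on the diagonal and $-s_i s_j$ off it; these are exactly the two cases `softmax_prime` handles below.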
import torch
import torch.nn.functional as F

#----------- Implementing the math -----------#
def softmax(z):
    return z.exp() / z.exp().sum(dim=1, keepdim=True)
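# A stability note (an addition, not part of the original gist): exponentiating
# large logits can overflow, so a common equivalent form shifts the logits by
# their row-wise max before exponentiating. Sketched here for reference:
def softmax_stable(z):
    shifted = z - z.max(dim=1, keepdim=True)[0]
    return shifted.exp() / shifted.exp().sum(dim=1, keepdim=True)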
def softmax_prime(z):
    sm = softmax(z).squeeze()
    sm_size = sm.shape[0]
    sm_ps = []
    for i, sm_i in enumerate(sm):
        for j, sm_j in enumerate(sm):
            if i == j:
                # First case, i == j: differentiating the softmax of a neuron w.r.t. itself
                sm_p = sm_i * (1 - sm_i)
            else:
                # Second case, i != j: differentiating the softmax of a neuron w.r.t. another neuron
                sm_p = -sm_i * sm_j
            sm_ps.append(sm_p)
    return torch.tensor(sm_ps).view(sm_size, sm_size)
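# A minimal vectorized sketch (not part of the original gist): the same
# Jacobian can be written as diag(s) - s s^T, avoiding the double loop.
# The unsqueeze-based outer product is used instead of torch.outer to
# stay compatible with older PyTorch versions.
def softmax_prime_vectorized(z):
    sm = softmax(z).squeeze()
    return torch.diag(sm) - sm.unsqueeze(1) * sm.unsqueeze(0)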
z = torch.tensor([[4., 2.]], requires_grad=True)
sm_p = softmax_prime(z)

#----------- Using PyTorch autograd -----------#
torch_sm = F.softmax(z, dim=1)
# To extract the first row of the Jacobian matrix, backpropagate [[1., 0.]].
# retain_graph=True because we call backward() again for the second row.
torch_sm.backward(torch.tensor([[1., 0.]]), retain_graph=True)
r1 = z.grad
# Reset the accumulated gradient before the second backward pass.
z.grad = torch.zeros_like(z)
# To extract the second row of the Jacobian matrix, backpropagate [[0., 1.]].
torch_sm.backward(torch.tensor([[0., 1.]]))
r2 = z.grad
torch_sm_p = torch.cat((r1, r2))

#----------- Comparing outputs -----------#
print(f"PyTorch Softmax': \n{torch_sm_p} \nOur Softmax': \n{sm_p}")
'''
Out:
PyTorch Softmax':
tensor([[ 0.1050, -0.1050],
        [-0.1050,  0.1050]])
Our Softmax':
tensor([[ 0.1050, -0.1050],
        [-0.1050,  0.1050]])
'''
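#----------- Cross-checking with torch.autograd.functional -----------#
# A hedged extra check, not part of the original gist: assuming PyTorch
# 1.5+ is installed, torch.autograd.functional.jacobian computes the
# full Jacobian in one call instead of one backward() per row.
from torch.autograd.functional import jacobian

z_check = torch.tensor([[4., 2.]])
# The softmax output is (1, 2) and the input is (1, 2), so the raw Jacobian
# has shape (1, 2, 1, 2); squeeze() reduces it to the expected (2, 2) matrix.
J = jacobian(lambda x: F.softmax(x, dim=1), z_check).squeeze()
print(J)  # should match the two matrices printed above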