@wassname
Created September 16, 2019 01:43
jaccard distance loss pytorch [draft]
#!/usr/bin/env python
# coding: utf-8
import numpy as np
import matplotlib.pyplot as plt
import torch
def jaccard_distance_loss(y_true, y_pred, smooth=100):
    """
    Jaccard = (|X & Y|) / (|X| + |Y| - |X & Y|)
            = sum(|A*B|) / (sum(|A|) + sum(|B|) - sum(|A*B|))

    The Jaccard distance loss is useful for unbalanced datasets. It has been
    shifted so it converges on 0 and smoothed to avoid an exploding or
    vanishing gradient.

    Ref: https://en.wikipedia.org/wiki/Jaccard_index
    @url: https://gist.github.com/wassname/17cbfe0b68148d129a3ddaa227696496
    @author: wassname
    """
    intersection = (y_true * y_pred).abs().sum(dim=-1)
    sum_ = torch.sum(y_true.abs() + y_pred.abs(), dim=-1)
    jac = (intersection + smooth) / (sum_ - intersection + smooth)
    return (1 - jac) * smooth
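# --- Example usage (a minimal sketch, not part of the original gist) ---
# The loss expects probabilities, so raw logits should go through a sigmoid
# first. The tiny linear "model", tensor shapes and optimizer settings below
# are illustrative assumptions only.
torch.manual_seed(0)
example_model = torch.nn.Linear(8, 4)
example_opt = torch.optim.SGD(example_model.parameters(), lr=0.1)
example_x = torch.randn(16, 8)
example_target = (torch.rand(16, 4) > 0.8).float()  # sparse, unbalanced labels
for _ in range(5):
    example_probs = torch.sigmoid(example_model(example_x))
    example_loss = jaccard_distance_loss(example_target, example_probs).mean()
    example_opt.zero_grad()
    example_loss.backward()
    example_opt.step()
print('example training loss', example_loss.item())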
# Test and plot
y_pred = torch.from_numpy(np.array([np.arange(-10, 10 + 0.1, 0.1)]).T)
y_true = torch.from_numpy(np.zeros(y_pred.shape))

name = 'jaccard_distance_loss'
loss = jaccard_distance_loss(y_true, y_pred).numpy()
plt.title(name)
plt.plot(y_pred.numpy(), loss)
plt.xlabel('abs prediction error')
plt.ylabel('loss')
plt.show()
name = 'binary cross entropy'
# F.binary_cross_entropy expects (input, target) with input in [0, 1],
# so clamp the predictions before comparing against the all-zero targets.
loss = torch.nn.functional.binary_cross_entropy(
    y_pred.clamp(1e-7, 1 - 1e-7), y_true, reduction='none'
).mean(-1).numpy()
plt.title(name)
plt.plot(y_pred.numpy(), loss)
plt.xlabel('abs prediction error')
plt.ylabel('loss')
plt.show()
# Test
print("TYPE | almost right | half right | all wrong")
y_true = torch.from_numpy(np.array([[0, 0, 1, 0], [0, 0, 1, 0], [0, 0, 1., 0.]]))
y_pred = torch.from_numpy(np.array([[0, 0, 0.9, 0], [0, 0, 0.1, 0], [1, 1, 0, 1]]))

r1 = jaccard_distance_loss(y_true, y_pred).numpy()
print('jaccard_distance_loss', r1)
print('jaccard_distance_loss scaled', r1 / r1.max())
assert r1[0] < r1[1]
assert r1[1] < r1[2]
# F.binary_cross_entropy expects (input, target), i.e. predictions first.
r2 = torch.nn.functional.binary_cross_entropy(
    y_pred, y_true, reduction='none'
).mean(-1).numpy()
print('binary_crossentropy', r2)
print('binary_crossentropy_scaled', r2 / r2.max())
assert r2[0] < r2[1]
assert r2[1] < r2[2]
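# --- Optional nn.Module wrapper (an assumption, not in the original gist) ---
# Wrapping the function as a module lets it be swapped in wherever criteria
# like nn.BCELoss are used; note the (input, target) argument order that
# torch loss modules follow.
class JaccardDistanceLoss(torch.nn.Module):
    def __init__(self, smooth=100):
        super().__init__()
        self.smooth = smooth

    def forward(self, y_pred, y_true):
        return jaccard_distance_loss(y_true, y_pred, self.smooth).mean()

criterion = JaccardDistanceLoss()
print('JaccardDistanceLoss module', criterion(y_pred, y_true).item())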