Skip to content

Instantly share code, notes, and snippets.

@zhen8838
Created April 4, 2021 13:38
Show Gist options
  • Save zhen8838/523250cd88a035f5ddc6cc833a388d9b to your computer and use it in GitHub Desktop.
def focal_sigmoid_cross_entropy_with_logits(labels: "tf.Tensor",
                                            logits: "tf.Tensor",
                                            gamma: float = 2.0,
                                            alpha: float = 0.25):
    """Compute the element-wise sigmoid focal loss from raw logits.

    The result is the sigmoid binary cross-entropy of ``labels`` vs ``logits``
    multiplied by ``alpha_t * (1 - p_t) ** gamma``, where ``p_t`` is the
    predicted probability of the target class and ``alpha_t`` is ``alpha``
    for positive targets and ``1 - alpha`` for negative ones.

    Args:
        labels: Target values, broadcast-compatible with ``logits``. They are
            cast to ``logits.dtype``, so integer labels are accepted.
            Presumably in ``[0, 1]`` (hard or soft targets); not validated.
        logits: Unnormalised (pre-sigmoid) scores.
        gamma: Focusing exponent applied to ``1 - p_t``.
        alpha: Weight for positive targets; negatives are weighted by
            ``1 - alpha``.

    Returns:
        Tensor of per-element (unreduced) loss values.
    """
    # Imported locally: the module has no top-level TensorFlow import, and the
    # string annotations above keep the definition itself import-free.
    import tensorflow as tf

    logits = tf.convert_to_tensor(logits)
    # sigmoid_cross_entropy_with_logits requires matching dtypes.
    labels = tf.cast(labels, logits.dtype)

    probs = tf.nn.sigmoid(logits)
    neg_labels = 1.0 - labels
    # Probability mass on the *wrong* class, i.e. 1 - p_t. Raising it to
    # `gamma` shrinks the contribution of well-classified examples.
    one_minus_pt = (1.0 - probs) * labels + probs * neg_labels
    alpha_t = alpha * labels + (1.0 - alpha) * neg_labels
    focal_weight = alpha_t * tf.math.pow(one_minus_pt, gamma)

    # Keyword arguments: TF1's signature starts with a `_sentinel` parameter,
    # so the positional form only works on TF2.
    bce = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
    return bce * focal_weight
def dev_focal_loss_val(*, gamma=2, alpha=1):
    """Plot focal loss against plain sigmoid BCE for all-positive targets.

    Logits are swept over ``tf.range(-20, 20, 0.01)`` with every label set to
    1, and both losses are plotted against the sigmoid confidence. A dashed
    red reference line is drawn at loss 0.2, and the y-axis is limited to
    ``[0, 5]``.

    Drawing happens on the current matplotlib figure. ``plt.show()`` is not
    called, so displaying or saving the figure is left to the caller.

    Args:
        gamma: Focusing exponent passed to the focal loss (default 2).
        alpha: Positive-class weight passed to the focal loss (default 1).
    """
    import tensorflow as tf
    import matplotlib.pyplot as plt

    logits = tf.range(-20, 20, 0.01)
    labels = tf.ones_like(logits)
    # Shared x-axis for both curves; computed once.
    confidence = tf.nn.sigmoid(logits).numpy()

    foloss = focal_sigmoid_cross_entropy_with_logits(
        labels, logits, gamma=gamma, alpha=alpha)
    bceloss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)

    plt.plot(confidence, foloss.numpy(), label='focal loss')
    plt.plot(confidence, bceloss.numpy(), label='bceloss')
    plt.hlines(0.2, 0, 1, colors='r', linestyles='--', label='0.2')
    plt.legend()
    plt.xlabel('confidence')
    plt.ylabel('loss value')
    plt.title(f'focal loss gamma : {gamma} alpha : {alpha}')
    plt.ylim((0, 5))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment