Skip to content

Instantly share code, notes, and snippets.

@iantimmis
Created October 12, 2020 03:54
Show Gist options
  • Save iantimmis/f3cfa3953db4817c54c874b812e5e5f3 to your computer and use it in GitHub Desktop.
Save iantimmis/f3cfa3953db4817c54c874b812e5e5f3 to your computer and use it in GitHub Desktop.
Numerically stable softmax and softmax cross-entropy in NumPy
import numpy as np
def naive_softmax(logits):
    """Compute softmax directly as ``exp(logits) / sum(exp(logits))``.

    Kept deliberately naive as a reference for ``stable_softmax``.

    Failure modes:
    * If any entry is very large, ``np.exp`` overflows to ``inf`` and the
      division ``inf / inf`` yields ``nan``.
    * If all entries are very negative, every ``np.exp`` underflows to 0 and
      the division ``0 / 0`` yields ``nan``.

    Args:
        logits: array-like of unnormalized scores. The sum is taken over all
            elements (no axis argument), so pass a single score vector.

    Returns:
        ndarray of probabilities with the same shape as ``logits``.
    """
    exp_logits = np.exp(logits)
    return exp_logits / np.sum(exp_logits)
def stable_softmax(logits):
    """Compute softmax without overflow by shifting logits by their maximum.

    Mathematically equivalent to softmax. Softmax is invariant to adding the
    same constant to every logit, so subtracting ``max(logits)`` leaves the
    result unchanged. After the shift, every exponent is ``<= 0``, so
    ``np.exp`` cannot overflow. At least one term equals ``exp(0) == 1``, so
    the denominator is ``>= 1`` and cannot underflow to 0.

    Args:
        logits: array-like of unnormalized scores. Max and sum are taken over
            all elements (no axis argument), so pass a single score vector.

    Returns:
        ndarray of probabilities with the same shape as ``logits``.
    """
    max_val = np.max(logits)
    safe_exp_logits = np.exp(logits - max_val)
    # Normalize by the shifted sum only; the shift cancels out of the ratio.
    # (Multiplying by max_val here would break the equivalence and divide by
    # zero whenever max(logits) == 0.)
    return safe_exp_logits / np.sum(safe_exp_logits)
def naive_softmax_with_cross_entropy(logits, t):
    """Softmax plugged into categorical cross-entropy, computed naively.

    Returns ``-sum(t * log(softmax(logits)))``, where softmax is
    ``naive_softmax``.

    Failure modes:
    * Inherits ``naive_softmax``'s overflow/underflow, both of which give
      ``nan``.
    * If a probability underflows to 0, ``np.log`` returns ``-inf``. The loss
      is then ``inf`` where ``t`` is positive, or ``nan`` where ``t`` is 0,
      because ``0 * -inf`` is ``nan``.

    Args:
        logits: array-like of unnormalized scores.
        t: target distribution broadcastable against ``logits`` (e.g. one-hot).

    Returns:
        Scalar cross-entropy loss.
    """
    probs = naive_softmax(logits)
    # Cross-entropy uses the *log* probabilities; without the log this would
    # be -p(target) rather than -log p(target).
    return -np.sum(t * np.log(probs))
def stable_softmax_with_cross_entropy(logits, t):
    """Softmax cross-entropy computed in log space to avoid overflow.

    Mathematically equivalent to ``-sum(t * log(softmax(logits)))``, including
    for soft (non one-hot) targets. It is computed via log-softmax:

        log_softmax(z) = (z - m) - log(sum(exp(z - m))),   m = max(z)

    Every exponent is ``<= 0``, so ``np.exp`` cannot overflow. The inner sum
    is ``>= 1``, so its log is finite. No explicit probability is formed, so
    there is no ``log(0)``.

    Args:
        logits: array-like of unnormalized scores. Max and sum are taken over
            all elements (no axis argument), so pass a single score vector.
        t: target distribution broadcastable against ``logits`` (e.g. one-hot).

    Returns:
        Scalar cross-entropy loss.
    """
    max_val = np.max(logits)
    safe_logits = logits - max_val
    # Subtract the log of the *shifted* normalizer from the *shifted* logits.
    # Using the full logsumexp (max_val + ...) here would leave max_val
    # uncancelled and offset the loss by max_val * sum(t).
    log_probs = safe_logits - np.log(np.sum(np.exp(safe_logits)))
    return -np.sum(t * log_probs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment