@SuperShinyEyes
Created January 9, 2021 13:54
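The snippet below computes the Shannon entropy, cross-entropy, and KL divergence (all in bits) of two discrete distributions, following the video linked in the docstring.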
'''
https://youtu.be/tRsSi_sqXjI
'''
from typing import List

import numpy as np

def entropy(ps: List[float]) -> float:
    # Shannon entropy in bits: H(p) = -sum_i p_i * log2(p_i)
    return np.sum([-np.log2(p) * p for p in ps])

def cross_entropy(p_trues: List[float], p_preds: List[float]) -> float:
    # Cross-entropy in bits: H(p, q) = -sum_i p_i * log2(q_i)
    return np.sum([-np.log2(q) * p for p, q in zip(p_trues, p_preds)])

def kl_divergence(p_trues: List[float], p_preds: List[float]) -> float:
    # KL divergence in bits: D_KL(p || q) = H(p, q) - H(p)
    return cross_entropy(p_trues, p_preds) - entropy(p_trues)

ps = np.array([1, 1, 4, 4, 10, 10, 35, 35]) / 100
qs = np.array([25, 25, 12.5, 12.5, 6.25, 6.25, 3.125, 3.125]) / 100
print(f'{entropy(ps)=:.2f}')
print(f'{cross_entropy(p_trues=ps, p_preds=qs)=:.2f}')
print(f'{kl_divergence(p_trues=ps, p_preds=qs)=:.2f}')
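As a sanity check (an addition, not part of the original gist), the same quantities can be verified with scipy.stats.entropy: called with one distribution it returns the Shannon entropy of pk, and with a second distribution it returns the KL divergence D_KL(pk || qk); base=2 matches the log2 convention used above.

# Hypothetical sanity check against SciPy (assumes scipy is installed).
from scipy.stats import entropy as scipy_entropy

# Shannon entropy of ps in bits; should print ~2.23, matching entropy(ps).
print(f'{scipy_entropy(ps, base=2)=:.2f}')
# D_KL(ps || qs) in bits; should print ~2.35, matching kl_divergence above.
print(f'{scipy_entropy(ps, qs, base=2)=:.2f}')

The cross-entropy then follows from the identity in kl_divergence: H(p, q) = H(p) + D_KL(p || q) ≈ 2.23 + 2.35 = 4.58 bits.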