@davidwhogg
Created June 6, 2016 18:57
import numpy as np


def hoggsumexp(qns, dqn_dams, diag=False):
    """
    # purpose:
    - Computes L = log(sum(exp(qns), axis=-1)).
    - Also computes its M-dimensional gradient components dL / da_m.
    # input
    - qns: ndarray of shape [n1, n2, n3, ..., nD, N]
    - dqn_dams: ndarray of shape [n1, n2, n3, ..., nD, N, M]
    - diag: if True, then qns.shape == dqn_dams.shape (so M == N) and
      dqn_dams holds only the diagonal terms dq_n / da_n (no sum over n is taken)
    # output
    - L: ndarray of shape [n1, n2, n3, ..., nD]
    - dL_dams: ndarray of shape [n1, n2, n3, ..., nD, M]
    # issues
    - Not exhaustively tested.
    """
    axis = len(qns.shape) - 1
    if diag:
        assert qns.shape == dqn_dams.shape
    # Subtract the max along the summed axis (not the global max), so that
    # rows far below the global max do not underflow to log(0).
    Q = np.max(qns, axis=axis, keepdims=True)
    expqns = np.exp(qns - Q)
    expL = np.sum(expqns, axis=axis)
    if diag:
        # dL / da_n = exp(q_n) * (dq_n / da_n) / sum_n' exp(q_n')
        numerator = expqns * dqn_dams
    else:
        # dL / da_m = sum_n exp(q_n) * (dq_n / da_m) / sum_n' exp(q_n')
        numerator = np.sum(np.expand_dims(expqns, axis + 1) * dqn_dams, axis=axis)
    return np.log(expL) + np.squeeze(Q, axis=axis), numerator / np.expand_dims(expL, axis)
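
Since the docstring notes the function is "not exhaustively tested", here is a minimal sketch of one way to check it: compare the returned gradient against central finite differences. Everything in the test setup is an assumption made for illustration and is not part of the gist: the sizes `K`, `N`, `M`, the linear model `q_n = sum_m A[n, m] * a_m`, and the helper `q_of`. The copy of `hoggsumexp` is included only so the snippet runs on its own; it takes the max along the last axis, which is a choice made here for numerical stability.

```python
import numpy as np


def hoggsumexp(qns, dqn_dams, diag=False):
    # Stand-alone copy of the gist's function so this sketch is runnable.
    # Assumption: the max is taken along the summed (last) axis.
    axis = len(qns.shape) - 1
    if diag:
        assert qns.shape == dqn_dams.shape
    Q = np.max(qns, axis=axis, keepdims=True)
    expqns = np.exp(qns - Q)
    expL = np.sum(expqns, axis=axis)
    if diag:
        numerator = expqns * dqn_dams
    else:
        numerator = np.sum(np.expand_dims(expqns, axis + 1) * dqn_dams, axis=axis)
    return np.log(expL) + np.squeeze(Q, axis=axis), numerator / np.expand_dims(expL, axis)


# Hypothetical test problem: q_n is linear in the parameters a_m,
# so dq_n / da_m is simply A[n, m].
rng = np.random.default_rng(42)
K, N, M = 3, 5, 4
A = rng.normal(size=(K, N, M))
a = rng.normal(size=M)


def q_of(params):
    # Returns the q_n values, shape (K, N).
    return A @ params


L, dL = hoggsumexp(q_of(a), A)

# Central finite differences of L with respect to each a_m.
eps = 1e-6
fd = np.empty((K, M))
for m in range(M):
    step = np.zeros(M)
    step[m] = eps
    L_plus, _ = hoggsumexp(q_of(a + step), A)
    L_minus, _ = hoggsumexp(q_of(a - step), A)
    fd[:, m] = (L_plus - L_minus) / (2 * eps)

max_err = np.max(np.abs(dL - fd))
print("max |analytic - finite difference| =", max_err)

# diag=True with dq_n / da_n = 1 should return the softmax of q,
# whose entries sum to one along the last axis.
L_diag, dL_diag = hoggsumexp(q_of(a), np.ones((K, N)), diag=True)
print("softmax row sums:", dL_diag.sum(axis=-1))
```

Agreement here only says the analytic gradient matches the numerical one for this single linear test case; it does not exercise large dynamic ranges or higher-dimensional leading shapes.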
@davidwhogg
Author

Calling @dfm: Can you look at this and tell me if I am full of shit?