Skip to content

Instantly share code, notes, and snippets.

@shhong
Last active May 19, 2023 11:45
Show Gist options
  • Star 10 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save shhong/1021654 to your computer and use it in GitHub Desktop.
Save shhong/1021654 to your computer and use it in GitHub Desktop.
NSB entropy in python
#!/usr/bin/env python
"""
nsb_entropy.py
June, 2011 written by Sungho Hong, Computational Neuroscience Unit, Okinawa Institute of
Science and Technology
May, 2019 updated to Python 3 by Charlie Strauss, Los Alamos National Lab
This script is a python version of Mathematica functions by Christian Mendl
implementing the Nemenman-Shafee-Bialek (NSB) estimator of entropy. For the
details of the method, check out the references below.
It depends on mpmath and numpy package.
References:
http://christian.mendl.net/pages/software.html
http://nsb-entropy.sourceforge.net
Ilya Nemenman, Fariel Shafee, and William Bialek. Entropy and Inference,
Revisited. arXiv:physics/0108025
Ilya Nemenman, William Bialek, and Rob de Ruyter van Steveninck. Entropy and
information in neural spike trains: Progress on the sampling problem. Physical
Review E 69, 056111 (2004)
"""
from mpmath import psi, rf, power, quadgl, mp, memoize
import numpy as np
DPS = 40
psi = memoize(psi)
rf = memoize(rf)
power = memoize(power)
def make_nxkx(n, K):
"""
Return the histogram of the input histogram n assuming that the number of
all bins is K.
>>> from numpy import array
>>> nTest = array([4, 2, 3, 0, 2, 4, 0, 0, 2])
>>> make_nxkx(nTest, 9) == {0: 3, 2: 3, 3: 1, 4: 2}
True
"""
nxkx = {}
nn = n[n > 0]
unn = np.unique(nn)
for x in unn:
nxkx[x] = (nn == x).sum()
if K > nn.size:
nxkx[0] = K - nn.size
return nxkx
def _xi(beta, K):
return psi(0, K * beta + 1) - psi(0, beta + 1)
def _dxi(beta, K):
return K * psi(1, K * beta + 1) - psi(1, beta + 1)
def _S1i(x, nxkx, beta, N, kappa):
return (
nxkx[x]
* (x + beta)
/ (N + kappa)
* (psi(0, x + beta + 1) - psi(0, N + kappa + 1))
)
def _S1(beta, nxkx, N, K):
kappa = beta * K
rx = np.array([_S1i(x, nxkx, beta, N, kappa) for x in nxkx])
return -rx.sum()
def _rhoi(x, nxkx, beta):
return power(rf(beta, np.double(x)), nxkx[x])
def _rho(beta, nxkx, N, K):
kappa = beta * K
rx = np.array([_rhoi(x, nxkx, beta) for x in nxkx])
return rx.prod() / rf(kappa, np.double(N))
def _Si(w, nxkx, N, K):
sbeta = w / (1 - w)
beta = sbeta * sbeta
return (
_rho(beta, nxkx, N, K)
* _S1(beta, nxkx, N, K)
* _dxi(beta, K)
* 2
* sbeta
/ (1 - w)
/ (1 - w)
)
def _measure(w, nxkx, N, K):
sbeta = w / (1 - w)
beta = sbeta * sbeta
return _rho(beta, nxkx, N, K) * _dxi(beta, K) * 2 * sbeta / (1 - w) / (1 - w)
def S(nxkx, N, K):
"""
Return the estimated entropy. nxkx is the histogram of the input histogram
constructed by make_nxkx. N is the total number of samples, and K is the
degree of freedom.
>>> from numpy import array
>>> nTest = array([4, 2, 3, 0, 2, 4, 0, 0, 2])
>>> K = 9 # which is actually equal to nTest.size.
>>> S(make_nxkx(nTest, K), nTest.sum(), K)
1.940646728502697166844598206814653716492
"""
mp.dps = DPS
mp.pretty = True
f = lambda w: _Si(w, nxkx, N, K)
g = lambda w: _measure(w, nxkx, N, K)
return quadgl(f, [0, 1], maxdegree=20) / quadgl(g, [0, 1], maxdegree=20)
def _S2i_diag(x, nxkx, beta, N, kappa):
xbeta = x + beta
Nkappa2 = N + kappa + 2
psNK2 = psi(0, Nkappa2)
ps1NK2 = psi(1, Nkappa2)
s1 = (psi(0, xbeta + 2) - psNK2) ** 2 + psi(1, xbeta + 2) - ps1NK2
s1 *= nxkx[x] * xbeta * (xbeta + 1)
s2 = (psi(0, xbeta + 1) - psNK2) ** 2 - ps1NK2
s2 *= nxkx[x] * (nxkx[x] - 1) * xbeta * xbeta
return s1 + s2
def _S2i_nondiag(x1, x2, nxkx, beta, N, kappa):
psNK2 = psi(0, N + kappa + 2)
ps1NK2 = psi(1, N + kappa + 2)
x1beta = x1 + beta
x2beta = x2 + beta
s1 = (psi(0, x1beta + 1) - psNK2) * (psi(0, x2beta + 1) - psNK2) - ps1NK2
s1 = s1 * nxkx[x1] * nxkx[x2] * x1beta * x2beta
return s1
def _S2(beta, nxkx, N, K):
kappa = beta * K
Nkappa = N + kappa
nx = np.array(list(nxkx.keys()))
Nnx = nx.size
dsum = 0.0
for x in nx:
dsum = dsum + _S2i_diag(x, nxkx, beta, N, kappa)
ndsum = 0.0
for i in range(Nnx - 1):
x1 = nx[i]
for j in range(i + 1, Nnx):
x2 = nx[j]
if x1 == x2:
ndsum = ndsum + _S2i_diag(x1, nxkx, beta, N, kappa)
else:
ndsum = ndsum + _S2i_nondiag(x1, x2, nxkx, beta, N, kappa)
return (dsum + 2 * ndsum) / (Nkappa) / (Nkappa + 1)
def _dSi(w, nxkx, N, K):
sbeta = w / (1 - w)
beta = sbeta * sbeta
return (
_rho(beta, nxkx, N, K)
* _S2(beta, nxkx, N, K)
* _dxi(beta, K)
* 2
* sbeta
/ (1 - w)
/ (1 - w)
)
def dS(nxkx, N, K):
"""
Return the mean squared flucuation of the entropy.
>>> from numpy import array, sqrt
>>> nTest = np.array([4, 2, 3, 0, 2, 4, 0, 0, 2])
>>> K = 9 # which is actually equal to nTest.size.
>>> nxkx = make_nxkx(nTest, K)
>>> s = S(nxkx, nTest.sum(), K)
>>> ds = dS(nxkx, nTest.sum(), K)
>>> ds
3.790453283682443237187782792251041316212
>>> sqrt(ds-s**2) # the standard deviation for the estimated entropy.
0.1560242251518078487118059349690693094484
"""
mp.dps = DPS
mp.pretty = True
f = lambda w: _dSi(w, nxkx, N, K)
g = lambda w: _measure(w, nxkx, N, K)
return quadgl(f, [0, 1], maxdegree=20) / quadgl(g, [0, 1], maxdegree=20)
if __name__ == "__main__":
import doctest
doctest.testmod()
@shhong
Copy link
Author

shhong commented Jun 25, 2011

  • Increased DPS to ensure correct evaluation in some cases.
  • Functions are memoized for speedup.

@cems2
Copy link

cems2 commented May 6, 2019

Hello, the code has errors in python3. I will post the revised working code below. I am also posting the error output here in this comment so that people searching for this error can find it.

**********************************************************************
File "__main__", line 167, in __main__.dS
Failed example:
    ds = dS(nxkx, nTest.sum(), K)
Exception raised:
    Traceback (most recent call last):
      File "/Users/cex/anaconda3/envs/jupyter2/lib/python3.7/doctest.py", line 1329, in __run
        compileflags, 1), test.globs)
      File "<doctest __main__.dS[5]>", line 1, in <module>
        ds = dS(nxkx, nTest.sum(), K)
      File "<ipython-input-142-6287d17fa824>", line 179, in dS
        return quadgl(f, [0, 1],maxdegree=20)/quadgl(g, [0, 1],maxdegree=20)
      File "/Users/cex/anaconda3/envs/jupyter2/lib/python3.7/site-packages/mpmath/calculus/quadrature.py", line 809, in quadgl
        return ctx.quad(*args, **kwargs)
      File "/Users/cex/anaconda3/envs/jupyter2/lib/python3.7/site-packages/mpmath/calculus/quadrature.py", line 742, in quad
        v, err = rule.summation(f, points[0], prec, epsilon, m, verbose)
      File "/Users/cex/anaconda3/envs/jupyter2/lib/python3.7/site-packages/mpmath/calculus/quadrature.py", line 232, in summation
        results.append(self.sum_next(f, nodes, degree, prec, results, verbose))
      File "/Users/cex/anaconda3/envs/jupyter2/lib/python3.7/site-packages/mpmath/calculus/quadrature.py", line 254, in sum_next
        return self.ctx.fdot((w, f(x)) for (x,w) in nodes)
      File "/Users/cex/anaconda3/envs/jupyter2/lib/python3.7/site-packages/mpmath/ctx_mp_python.py", line 944, in fdot
        for a, b in A:
      File "/Users/cems/anaconda3/envs/jupyter2/lib/python3.7/site-packages/mpmath/calculus/quadrature.py", line 254, in <genexpr>
        return self.ctx.fdot((w, f(x)) for (x,w) in nodes)
      File "<ipython-input-142-6287d17fa824>", line 177, in <lambda>
        f = lambda w: _dSi(w, nxkx, N, K)
      File "<ipython-input-142-6287d17fa824>", line 156, in _dSi
        return _rho(beta, nxkx, N, K) * _S2(beta, nxkx, N, K) * _dxi(beta, K) * 2*sbeta/(1-w)/(1-w)
      File "<ipython-input-142-6287d17fa824>", line 139, in _S2
        for x in nx:
    TypeError: iteration over a 0-d array
**********************************************************************
File "__main__", line 168, in __main__.dS
Failed example:
    ds
Exception raised:
    Traceback (most recent call last):
      File "/Users/cex/anaconda3/envs/jupyter2/lib/python3.7/doctest.py", line 1329, in __run
        compileflags, 1), test.globs)
      File "<doctest __main__.dS[6]>", line 1, in <module>
        ds
    NameError: name 'ds' is not defined
**********************************************************************
File "__main__", line 170, in __main__.dS
Failed example:
    sqrt(ds-s**2) # the standard deviation for the estimated entropy.
Exception raised:
    Traceback (most recent call last):
      File "/Users/cex/anaconda3/envs/jupyter2/lib/python3.7/doctest.py", line 1329, in __run
        compileflags, 1), test.globs)
      File "<doctest __main__.dS[7]>", line 1, in <module>
        sqrt(ds-s**2) # the standard deviation for the estimated entropy.
    NameError: name 'ds' is not defined
**********************************************************************
File "__main__", line 44, in __main__.make_nxkx
Failed example:
    make_nxkx(nTest, 9) # {0: 3, 2: 3, 3: 1, 4: 2}
Expected:
    {2: 3, 3: 1, 4: 2, 0: 3} 
Got:
    {2: 3, 3: 1, 4: 2, 0: 3}
**********************************************************************
2 items had failures:
   3 of   8 in __main__.dS
   1 of   3 in __main__.make_nxkx
***Test Failed*** 4 failures.

@cems2
Copy link

cems2 commented May 6, 2019

the problem is the code is not python3 ready. the issue is dict.keys() needs to be cast to a list and xrange needs to be range.

Finally the doctest for dictionaries is not stable either in order or white space.

here's is the revised code that runs under python3

#!/usr/bin/env python

"""
nsb_entropy.py

June, 2011 written by Sungho Hong, Computational Neuroscience Unit, Okinawa Institute of
Science and Technology
May 2019 updated to python3 by Charlie Strauss, Los Alamos National Lab

This script is a python version of Mathematica functions by Christian Mendl
implementing the Nemenman-Shafee-Bialek (NSB) estimator of entropy. For the
details of the method, check out the references below.

It depends on mpmath and numpy package.

References:
http://christian.mendl.net/pages/software.html
http://nsb-entropy.sourceforge.net
Ilya Nemenman, Fariel Shafee, and William Bialek. Entropy and Inference,
Revisited. arXiv:physics/0108025
Ilya Nemenman, William Bialek, and Rob de Ruyter van Steveninck. Entropy and
information in neural spike trains: Progress on the sampling problem. Physical
Review E 69, 056111 (2004)

"""
 
from mpmath import psi, rf, power, quadgl, mp, memoize
import numpy as np

DPS = 40

psi = memoize(psi)
rf = memoize(rf)
power = memoize(power)

def make_nxkx(n, K):
  """
  Return the histogram of the input histogram n assuming that the number of
  all bins is K.
  
    >>> from numpy import array
    >>> nTest = array([4, 2, 3, 0, 2, 4, 0, 0, 2])
    >>> make_nxkx(nTest, 9) ==  {0: 3, 2: 3, 3: 1, 4: 2}
    True
  """
  nxkx = {}
  nn = n[n>0]
  unn = np.unique(nn)
  for x in unn:
    nxkx[x] = (nn==x).sum()
  if K>nn.size:
    nxkx[0] = K-nn.size
  return nxkx

def _xi(beta, K):
  return psi(0, K*beta+1)-psi(0, beta+1)

def _dxi(beta, K):
  return K*psi(1, K*beta+1)-psi(1, beta+1)

def _S1i(x, nxkx, beta, N, kappa):
  return nxkx[x] * (x+beta)/(N+kappa) * (psi(0, x+beta+1)-psi(0, N+kappa+1))

def _S1(beta, nxkx, N, K):
  kappa = beta*K
  rx = np.array([_S1i(x, nxkx, beta, N, kappa) for x in nxkx])
  return -rx.sum()

def _rhoi(x, nxkx, beta):
  return power(rf(beta, np.double(x)), nxkx[x])

def _rho(beta, nxkx, N, K):  
  kappa = beta*K
  rx = np.array([_rhoi(x, nxkx, beta) for x in nxkx])
  return rx.prod()/rf(kappa, np.double(N))

def _Si(w, nxkx, N, K):
  sbeta = w/(1-w)
  beta = sbeta*sbeta
  return _rho(beta, nxkx, N, K) * _S1(beta, nxkx, N, K) * _dxi(beta, K) * 2*sbeta/(1-w)/(1-w)

def _measure(w, nxkx, N, K):
  sbeta = w/(1-w)
  beta = sbeta*sbeta
  return _rho(beta, nxkx, N, K) * _dxi(beta, K) * 2*sbeta/(1-w)/(1-w)

def S(nxkx, N, K):
  """
  Return the estimated entropy. nxkx is the histogram of the input histogram
  constructed by make_nxkx. N is the total number of samples, and K is the
  degree of freedom.
  
  >>> from numpy import array
  >>> nTest = array([4, 2, 3, 0, 2, 4, 0, 0, 2])
  >>> K = 9  # which is actually equal to nTest.size.
  >>> S(make_nxkx(nTest, K), nTest.sum(), K)
  1.940646728502697166844598206814653716492
  """
  mp.dps = DPS
  mp.pretty = True
  
  f = lambda w: _Si(w, nxkx, N, K)
  g = lambda w: _measure(w, nxkx, N, K)  
  return quadgl(f, [0, 1],maxdegree=20)/quadgl(g, [0, 1],maxdegree=20)

def _S2i_diag(x, nxkx, beta, N, kappa):
  xbeta = x+beta
  Nkappa2 = N+kappa+2
  psNK2 = psi(0, Nkappa2)
  ps1NK2 = psi(1, Nkappa2)
  
  s1  = (psi(0, xbeta+2)-psNK2)**2 + psi(1, xbeta+2) - ps1NK2
  s1 *= nxkx[x]*xbeta*(xbeta+1)
  
  s2 = (psi(0, xbeta+1)-psNK2)**2 - ps1NK2
  s2 *= nxkx[x]*(nxkx[x]-1)*xbeta*xbeta
  
  return s1+s2

def _S2i_nondiag(x1, x2, nxkx, beta, N, kappa):
  psNK2 = psi(0, N+kappa+2)
  ps1NK2 = psi(1, N+kappa+2)
  x1beta = x1+beta
  x2beta = x2+beta
  
  s1 = (psi(0, x1beta+1)-psNK2)*(psi(0, x2beta+1)-psNK2) - ps1NK2
  s1 = s1*nxkx[x1]*nxkx[x2]*x1beta*x2beta

  return s1

def _S2(beta, nxkx, N, K):
  kappa = beta*K
  Nkappa = N+kappa
  nx = np.array(list(nxkx.keys()))
  Nnx = nx.size
  
  dsum = 0.0
  for x in nx:
    dsum = dsum + _S2i_diag(x, nxkx, beta, N, kappa)
  
  ndsum = 0.0
  for i in range(Nnx-1):
    x1 = nx[i]
    for j in range(i+1, Nnx):
      x2 = nx[j]
      if x1==x2:
        ndsum = ndsum + _S2i_diag(x1, nxkx, beta, N, kappa)
      else:
        ndsum = ndsum + _S2i_nondiag(x1, x2, nxkx, beta, N, kappa)
  return (dsum + 2*ndsum)/(Nkappa)/(Nkappa+1)

def _dSi(w, nxkx, N, K):
  sbeta = w/(1-w)
  beta = sbeta*sbeta
  return _rho(beta, nxkx, N, K) * _S2(beta, nxkx, N, K) * _dxi(beta, K) * 2*sbeta/(1-w)/(1-w)

def dS(nxkx, N, K):
  """
  Return the mean squared flucuation of the entropy.
  
  >>> from numpy import array, sqrt
  >>> nTest = np.array([4, 2, 3, 0, 2, 4, 0, 0, 2])
  >>> K = 9  # which is actually equal to nTest.size.
  >>> nxkx = make_nxkx(nTest, K)
  >>> s = S(nxkx, nTest.sum(), K)
  >>> ds = dS(nxkx, nTest.sum(), K)
  >>> ds
  3.790453283682443237187782792251041316212
  >>> sqrt(ds-s**2) # the standard deviation for the estimated entropy.
  0.1560242251518078487118059349690693094484
  """  
  
  mp.dps = DPS
  mp.pretty = True
  
  f = lambda w: _dSi(w, nxkx, N, K)
  g = lambda w: _measure(w, nxkx, N, K)  
  return quadgl(f, [0, 1],maxdegree=20)/quadgl(g, [0, 1],maxdegree=20)

if __name__ == "__main__":
    pass
    import doctest
    doctest.testmod()

@shhong
Copy link
Author

shhong commented Jun 3, 2019

Thanks! Applied the changes.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment