Skip to content

Instantly share code, notes, and snippets.

@tabacof
Last active August 29, 2015 14:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tabacof/c72a2bc507e6ccd97ca1 to your computer and use it in GitHub Desktop.
Save tabacof/c72a2bc507e6ccd97ca1 to your computer and use it in GitHub Desktop.
# -*- coding: utf-8 -*-
import numpy as np
import pylab
from scipy.special import binom
from bayespy.nodes import Categorical, Binomial, Gate, Beta
#Jaynes' PT:TLoS example 4.1
def db(x):
    """Convert a ratio ``x`` to decibels, i.e. ``10 * log10(x)``.

    Works element-wise on NumPy arrays as well as on scalars, because it
    delegates to ``np.log10``.
    """
    return np.log10(x) * 10.0
def evidence(x):
    """Return the evidence of a probability ``x`` in decibels (Equation 4.8).

    Evidence is the log-odds ``10 * log10(x / (1 - x))``. A tiny epsilon
    (1e-100) is added to the denominator and to the odds. This keeps
    ``x == 1`` from dividing by zero and ``x == 0`` from taking ``log10(0)``.
    """
    tiny = 1E-100
    odds = x / (1.0 - x + tiny)
    # Same as db(tiny + odds): the decibel transform is written inline.
    return 10.0 * np.log10(tiny + odds)
def b(r, n, f):
    """Binomial probability of ``r`` successes in ``n`` trials (Equation 3.86).

    ``f`` is the per-trial success probability. The result is
    ``C(n, r) * f**r * (1 - f)**(n - r)``.
    """
    ways = binom(n, r)
    hits = np.power(f, r)
    misses = np.power(1.0 - f, n - r)
    return ways * hits * misses
# Prior probabilities of the three boxes (equation 4.31).
# A and B split the mass 1 - 1e-6 in the ratio 1 : 10; C receives the
# remainder, so the three priors sum to one.
priorA, priorB = (weight / 11 * (1 - 1e-6) for weight in (1.0, 10.0))
priorC = 1.0 - priorA - priorB
priors = [priorA, priorB, priorC]
print("Priors sum: ", sum(priors))

# Defect rate of each box, in the same order as ``priors``.
boxA, boxB, boxC = 1.0 / 3.0, 1.0 / 6.0, 99.0 / 100.0
boxes = [boxA, boxB, boxC]
def real_posterior(r, n):
    """Exact posterior evidence (dB) of each box after the data (Eqs. 4.33, 4.39).

    The data are ``r`` outcomes out of ``n`` trials, scored with the binomial
    likelihood ``b(r, n, rate)`` of each box's defect rate. For box ``i`` the
    result is::

        e(H_i | X) + db( P(D|H_i) * sum_{j != i} P(H_j)
                         / sum_{j != i} P(D|H_j) P(H_j) )

    Returns a list with one evidence value per entry of ``priors``.
    """
    # One likelihood per hypothesis, reused for both uses below.
    likelihoods = [b(r, n, boxes[k]) for k, _ in enumerate(priors)]
    joint = [prior * lik for prior, lik in zip(priors, likelihoods)]

    result = []
    for i, (prior, lik) in enumerate(zip(priors, likelihoods)):
        # Mass of the competing hypotheses, before and after seeing the data.
        others_prior = sum(p for j, p in enumerate(priors) if j != i)
        others_joint = sum(q for j, q in enumerate(joint) if j != i)
        result.append(evidence(prior) + db(lik * others_prior / others_joint))
    return result
# Range of sample sizes N (number of tests) swept by the main loop.
# N = 0 (no data) is handled separately by seeding the curves with the
# prior evidence, so the sweep itself must start at N >= 1.
minN = 1
maxN = 20
# Explicit checks instead of ``assert``: asserts are stripped when Python
# runs with -O, which would silently skip this validation.
if minN <= 0:
    raise ValueError("minN must be positive")
if maxN < minN:
    raise ValueError("maxN must be >= minN")
def bcalc(m, b):
    """Return the Beta ``alpha`` that puts the mode of Beta(alpha, b) at ``m``.

    Solving ``m = (alpha - 1) / (alpha + b - 2)`` for ``alpha`` gives
    ``(m*b - 2m + 1) / (1 - m)``. ``m`` must differ from 1.

    Note: the parameter ``b`` shadows the module-level ``b`` function inside
    this body. That is harmless here because the function is not called.
    """
    numerator = m * b - 2.0 * m + 1.0
    return numerator / (1.0 - m)
# Each curve gets its N = 0 point (no tests yet) here, because the sweep
# below only starts at N = minN. With no data the posterior equals the
# prior, so both the analytic and the variational curve for every box start
# at that box's prior evidence.
realPlot = [[evidence(p)] for p in priors]
varPlot = [[evidence(p)] for p in priors]
# Beta prior on each box's defect rate. bcalc places the mode of
# Beta(alpha, bb) at the box's rate, and the large ``bb`` makes the prior
# very sharp around it. These parameters do not depend on N, so they are
# computed once, outside the loop.
bb = 10000.0
variationalPriors = [[bcalc(i, bb), bb] for i in boxes]
for N in range(minN, maxN + 1):
    # Observation: in this experiment every tested item is bad.
    NBad = N  # Number of bad ones
    assert(NBad <= N)  # For when it's not always bad
    print("Number of tests:", N)
    # Build a fresh BayesPy model for this N. ``box`` is the categorical
    # choice of hypothesis; Gate presumably selects that box's Beta-distributed
    # defect rate, which drives the Binomial count of bad items.
    box = Categorical(priors)
    beta = Beta(variationalPriors, plates=(len(variationalPriors),))
    box_func = Gate(box, beta)
    bad = Binomial(N, box_func)
    bad.observe(NBad)
    # Update the approximate posterior of ``box`` given the observation.
    box.update()
    e = real_posterior(NBad, N)
    for index, ev in enumerate(e):
        realPlot[index].append(ev)
        # Assumes box.pdf(index) is the posterior probability of box == index.
        varPlot[index].append(evidence(box.pdf(index)))
def _plot_evidence(title, curves):
    """Plot one evidence curve (dB) per box and show the figure.

    ``curves`` holds one list per box. Element 0 is the prior evidence
    (N = 0); the remaining elements correspond to N = minN .. maxN.
    """
    # Explicit x positions. Index 0 is N = 0 and the rest are minN..maxN.
    # Plotting against implicit indices 0..len-1 only matched the tick
    # labels when minN == 1.
    xs = [0] + list(range(minN, maxN + 1))
    fig = pylab.figure()
    pylab.title(title)
    for curve in curves:
        pylab.plot(xs, curve)
    ax = fig.gca()
    ax.set_xticks(xs)
    ax.set_yticks(np.arange(-60, 20, 3))
    pylab.ylabel('Evidence (dB)')
    pylab.ylim([-60, 20])
    pylab.grid()
    pylab.legend([str(i) for i in range(len(priors))])
    pylab.show()


_plot_evidence('Variational approximation', varPlot)
_plot_evidence('Jaynes\' Analytical Answer', realPlot)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment