Skip to content

Instantly share code, notes, and snippets.

@khanh101
Last active September 10, 2020 14:16
Show Gist options
  • Save khanh101/e4999dc29e4bae779a5d9f155c2e1a52 to your computer and use it in GitHub Desktop.
Save khanh101/e4999dc29e4bae779a5d9f155c2e1a52 to your computer and use it in GitHub Desktop.
Simple Bayesian Inference for a Binomial Proportion (Beta–Binomial Conjugate Update)
from typing import Tuple, Iterator, Any
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
# from matplotlib.collections import PathCollection
from matplotlib.figure import Figure
from matplotlib.lines import Line2D
# Turn every floating-point warning category (divide, overflow, underflow,
# invalid) into a FloatingPointError for the whole process; draw_beta below
# catches that exception type.
np.seterr(all='raise')
def draw_beta(a: float, b: float, sampling_count: int = 500) -> Tuple[np.ndarray, np.ndarray]:
    """Sample the Beta(a, b) density on a uniform midpoint grid over (0, 1).

    The density is evaluated in log space and normalized so that its
    midpoint-rule integral (mean value times the unit-interval width) is 1.
    Working in log space keeps the curve usable for the large ``a``/``b``
    reached after many posterior updates. The raw product
    ``x**(a-1) * (1-x)**(b-1)`` underflows once the exponents reach about a
    hundred, which raises when numpy's floating-point errors are set to raise.

    Args:
        a: First Beta shape parameter.
        b: Second Beta shape parameter.
        sampling_count: Number of grid points (midpoints of equal-width cells).

    Returns:
        ``(x, p)``: ``x`` holds the cell midpoints, and ``p`` holds the
        normalized density values at those points.

    Raises:
        SystemExit: with status 1, after printing the message, if a
            floating-point error still occurs (e.g. infinite shape parameters).
    """
    interval = 1 / sampling_count
    # Cell midpoints: interval/2, 3*interval/2, ..., 1 - interval/2.
    x = (np.arange(sampling_count) + 0.5) * interval
    try:
        log_p = (a - 1) * np.log(x) + (b - 1) * np.log1p(-x)
        # Shift so the peak becomes exp(0) == 1: nothing can overflow and the
        # normalizing mean is at least 1 / sampling_count.
        log_p -= log_p.max()
        with np.errstate(under="ignore"):
            # Points far from the peak legitimately round to 0 (or subnormals).
            p = np.exp(log_p)
            p /= p.mean()
    except FloatingPointError as err:
        print(f"FloatingPointError: {err}")
        # Non-zero status: this is a failure, not a normal exit.
        raise SystemExit(1) from err
    return x, p
def bernoulli_trial(q: float, sampling_count: int) -> int:
    """Simulate ``sampling_count`` independent Bernoulli(q) trials and count successes.

    A trial succeeds when its uniform draw from [0, 1) falls below ``q``.
    """
    uniform_draws = np.random.random(size=(sampling_count,))
    successes = uniform_draws < q
    return successes.sum()
def update_beta(a: float, b: float, n: int, k: int) -> Tuple[float, float]:
    """Return the Beta posterior parameters after observing k successes in n trials.

    Successes are added to ``a`` and failures (``n - k``) are added to ``b``.
    """
    failures = n - k
    return a + k, b + failures
# global
# Mutable module-level state for the animation. data_gen rebinds a, b and i
# through `global`, so these names must remain module-level.
# Beta(a, b) parameters; a = b = 1.0 makes the starting density uniform.
a: float = 1.0
b: float = 1.0
# Success probability used to simulate each Bernoulli observation.
q: float = 0.7
fig: Figure
ax: Any
fig, ax = plt.subplots()
# Single line artist whose data run() replaces on every frame.
line: Line2D = ax.plot([], [], lw=1)[0]
# scat: PathCollection = ax.scatter([], [])
# Number of trials processed so far (incremented by data_gen).
i: int = 0
# end global
ax.grid()
def data_gen() -> Iterator[Tuple[np.ndarray, np.ndarray]]:
    """Yield successive posterior Beta densities, adding one Bernoulli trial per frame.

    Blocks on ``input`` before producing the first frame, then loops forever.
    Each step draws one Bernoulli(q) observation and folds it into the global
    Beta(a, b) posterior. It then increments the global counter ``i``, prints
    the posterior mean and standard deviation, and yields the sampled density
    curve as ``(x, y)``.
    """
    global a, b, i
    input("enter to start ...")
    trials = 1
    while True:
        successes = bernoulli_trial(q, trials)
        a, b = update_beta(a, b, trials, successes)
        total = a + b
        mean = a / total
        # Standard deviation of Beta(a, b): sqrt(ab / ((a+b)^2 (a+b+1))).
        std = np.sqrt((a * b) / ((total ** 2) * (total + 1)))
        i += 1
        print(f"iter {i} mean {mean} std {std}")
        grid, density = draw_beta(a, b)
        yield grid, density
def init():
    """Reset the axes to the unit square before the first animation frame."""
    ax.set(xlim=(0, 1), ylim=(0.0, 1.0))
def run(data: Tuple[np.ndarray, np.ndarray]):
    """Animation callback: draw one density frame, enlarging the y-range if needed.

    If the curve's peak exceeds the current top of the y-axis, the top is
    doubled until the peak fits and the axis is reset to start at 0. The
    range is only ever grown by this function, never shrunk.

    Args:
        data: ``(x, y)`` arrays for the curve, as yielded by the frame source.
    """
    x, y = data
    peak = y.max(initial=0)
    top = float(ax.get_ylim()[1])
    # Non-finite peaks cannot be fitted by doubling; leave the limits alone.
    if np.isfinite(peak) and top < peak:
        # Doubling a non-positive top (e.g. after the user panned the view
        # down) would never reach a positive peak and would loop forever, so
        # restart from 1.0.
        if top <= 0:
            top = 1.0
        while top < peak:
            top *= 2
        ax.set_ylim(0.0, top)
    line.set_data(x, y)
# Keep a module-level reference to the animation object; matplotlib's docs
# warn that an Animation with no live reference can be garbage-collected.
ani = animation.FuncAnimation(
    fig,
    run,
    frames=data_gen,
    init_func=init,
    interval=10,  # milliseconds between frames
    blit=False,
    repeat=False,
)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment