Last active
September 10, 2020 14:16
-
-
Save khanh101/e4999dc29e4bae779a5d9f155c2e1a52 to your computer and use it in GitHub Desktop.
Simple Bayesian inference for a binomial success probability (Beta–Bernoulli conjugate updating)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Tuple, Iterator, Any | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from matplotlib import animation | |
# from matplotlib.collections import PathCollection | |
from matplotlib.figure import Figure | |
from matplotlib.lines import Line2D | |
# Promote every NumPy floating-point condition (divide, over, under, invalid)
# to a FloatingPointError so numerical breakdown is never silently ignored.
np.seterr(all='raise')
def draw_beta(a: float, b: float, sampling_count: int = 500) -> Tuple[np.ndarray, np.ndarray]:
    """Sample the Beta(a, b) density on a uniform midpoint grid over (0, 1).

    The unnormalised density ``x**(a-1) * (1-x)**(b-1)`` is evaluated in log
    space and rescaled so that its Riemann sum over the grid equals 1, i.e.
    ``p.sum() / sampling_count == 1``.  The direct power form underflows once
    ``a``/``b`` grow large (after many posterior updates), which under a
    global ``np.seterr(all='raise')`` used to abort the program; the log-space
    form stays finite.

    Args:
        a: First Beta shape parameter (expected > 0).
        b: Second Beta shape parameter (expected > 0).
        sampling_count: Number of grid points.

    Returns:
        ``(x, p)``: the grid midpoints and the normalised density values.
    """
    # Cell midpoints (k + 0.5) / N never touch 0 or 1, where log(x) or
    # log(1 - x) would be undefined.
    x = (np.arange(sampling_count) + 0.5) / sampling_count
    log_p = (a - 1) * np.log(x) + (b - 1) * np.log1p(-x)
    # Shifting by the maximum makes the peak exp(0) == 1: nothing overflows
    # and the sum is >= 1, so the normalisation cannot divide by zero.
    # Far-tail values may underflow to (or through subnormals towards) 0,
    # which is harmless for plotting, so ignore only that condition even if
    # the caller enabled np.seterr(all='raise').
    with np.errstate(under='ignore'):
        p = np.exp(log_p - log_p.max())
        p *= sampling_count / p.sum()
    return x, p
def bernoulli_trial(q: float, sampling_count: int) -> int:
    """Simulate ``sampling_count`` independent Bernoulli(q) trials.

    Args:
        q: Success probability of each trial.
        sampling_count: Number of trials to simulate.

    Returns:
        The number of successes as a plain Python ``int`` (matching the
        annotation) rather than a NumPy integer scalar.
    """
    # np.random.random draws from [0, 1), so each draw is a success with
    # probability q.
    successes = np.random.random(size=(sampling_count,)) < q
    return int(successes.sum())
def update_beta(a: float, b: float, n: int, k: int) -> Tuple[float, float]:
    """Return the Beta posterior parameters after ``k`` successes in ``n`` trials.

    The Beta prior is conjugate to the binomial likelihood, so the update just
    adds the successes to ``a`` and the failures to ``b``.
    """
    failures = n - k
    return k + a, failures + b
# ---- module-level state shared by the animation callbacks ----------------
# Current Beta(a, b) posterior parameters; Beta(1, 1) is the uniform prior.
a: float = 1.0
b: float = 1.0
# Success probability used to simulate each Bernoulli observation.
q: float = 0.7
# Number of posterior updates performed so far.
i: int = 0

fig: Figure
ax: Any
fig, ax = plt.subplots()
# Single line artist whose data is replaced on every animation frame.
line: Line2D = ax.plot([], [], lw=1)[0]
ax.grid()
def data_gen() -> Iterator[Tuple[np.ndarray, np.ndarray]]:
    """Yield successive posterior density curves, one observation at a time.

    On first iteration, blocks on ``input`` until the user presses Enter. Then
    loops forever: simulates one trial with success probability ``q``, folds
    it into the global Beta(a, b) posterior, prints the iteration count with
    the posterior mean and standard deviation, and yields the ``(x, density)``
    arrays produced by ``draw_beta``.
    """
    global a, b, i
    input("enter to start ...")
    trials_per_step = 1
    while True:
        successes = bernoulli_trial(q, trials_per_step)
        a, b = update_beta(a, b, trials_per_step, successes)
        i += 1
        total = a + b
        mean = a / total
        std = np.sqrt((a * b) / ((total ** 2) * (total + 1)))
        print(f"iter {i} mean {mean} std {std}")
        yield draw_beta(a, b)
def init():
    """Set the initial axis limits for the animation (the unit square)."""
    ax.set_xlim(0, 1)
    ax.set_ylim(0.0, 1.0)
def run(data: Tuple[np.ndarray, np.ndarray]):
    """Draw one posterior curve, enlarging the y-axis when the curve exceeds it.

    The upper y-limit is repeatedly doubled (never shrunk) until it is at least
    the curve's peak, so the axis rescales only occasionally.
    """
    xs, ys = data
    peak = ys.max(initial=0)
    while peak > ax.get_ylim()[1]:
        ax.set_ylim(0.0, ax.get_ylim()[1] * 2)
    line.set_data(xs, ys)
# Hold a reference to the animation object: matplotlib stops animating once
# the FuncAnimation instance has been garbage-collected.
ani = animation.FuncAnimation(
    fig,
    run,
    data_gen,
    init_func=init,
    interval=10,
    blit=False,
    repeat=False,
)
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment