EM example
""" | |
This script is an example of learning parameters of a Bayesian network | |
with hidden variables. The following Bayesian net is analyzed: | |
(hidden) (observable) | |
K -----------> A | |
Where K (knowledge) is not observable and A (answers), which depends on K, is. | |
The sample consist of only observations of A. The goal is to estimate the | |
probability P(K = True) = theta, using the expectation maximization (EM) | |
algorithm. | |
The following CPD if assumed for P(A | K): | |
P(A = True | K = True ) = .9 | |
P(A = False | K = True ) = .1 (Chance of slip: mark the wrong answer even though knowing it) | |
P(A = True | K = False ) = .2 (Chance of guess: marking the right answer without knowing it) | |
P(A = False | K = False ) = .8 | |
""" | |
import random

# The real theta, which will be used to generate the sample.
# The algorithm will try to estimate this parameter.
REAL_THETA = 0.3

P_AT_GIVEN_KT = 1.0
P_AT_GIVEN_KF = 0.2

INITIAL_THETA_GUESS = 0.001
N_SAMPLES = 100000
TOL = 0.000001
# The sample consists of observations of A only.
sample = [
    random.random() < (P_AT_GIVEN_KT if k else P_AT_GIVEN_KF)
    for k in ((random.random() < REAL_THETA) for _ in range(N_SAMPLES))
]
# In this simple case, theta can also be estimated analytically in a single step.
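# Under the CPD above, P(A = True) = theta * 1.0 + (1 - theta) * 0.2
#                                  = 0.2 + 0.8 * theta,
# so theta = (P(A = True) - 0.2) / 0.8 = 5 * (N_AT/N - 0.2) / 4,
# where N_AT/N is the empirical frequency of A = True.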
N_AT = len(list(filter(lambda a: a, sample)))
N = len(sample)
theta_anal = 5*(N_AT/N - .2)/4
print("Analytical solution: theta = {}".format(theta_anal))
P_AF_GIVEN_KT = 1 - P_AT_GIVEN_KT
P_AF_GIVEN_KF = 1 - P_AT_GIVEN_KF

new_theta = INITIAL_THETA_GUESS
theta = -1
# Keep the whole trajectory of estimates produced by EM.
acc = [new_theta]
while abs(theta - new_theta) > TOL:
    theta = new_theta
    # E-step: estimate Nk, the expected number of datapoints attributed to
    # K = True, using the current estimate of theta = p(K = True).
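    # Responsibilities by Bayes' rule:
    #   p(K=T | A=a) = P(A=a | K=T) * theta
    #                  / (P(A=a | K=T) * theta + P(A=a | K=F) * (1 - theta))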
    p_kt_given_af = P_AF_GIVEN_KT*theta/(P_AF_GIVEN_KT*theta + P_AF_GIVEN_KF*(1 - theta))
    p_kt_given_at = P_AT_GIVEN_KT*theta/(P_AT_GIVEN_KT*theta + P_AT_GIVEN_KF*(1 - theta))
    Nk = sum(p_kt_given_at if a else p_kt_given_af for a in sample)
    # M-step: the MLE for theta is the average responsibility over the sample.
    new_theta = Nk/len(sample)
    acc.append(new_theta)
    print("Next theta: {:2.6f}".format(new_theta))
print("Learned model. P(K = True) = theta = {:2.6f}".format(new_theta)) |
% python em_simple.py
Analytical solution: theta = 0.30017499999999997
Next theta: 0.002192
Next theta: 0.004782
Next theta: 0.010326
Next theta: 0.021823
Next theta: 0.044170
Next theta: 0.082609
Next theta: 0.136646
Next theta: 0.194439
Next theta: 0.240698
Next theta: 0.269873
Next theta: 0.285603
Next theta: 0.293373
Next theta: 0.297046
Next theta: 0.298745
Next theta: 0.299523
Next theta: 0.299879
Next theta: 0.300040
Next theta: 0.300114
Next theta: 0.300147
Next theta: 0.300162
Next theta: 0.300169
Next theta: 0.300172
Next theta: 0.300174
Next theta: 0.300174
Learned model. P(K = True) = theta = 0.300174
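As a sanity check (not part of the original script), EM is guaranteed never to
decrease the observed-data log-likelihood, so the estimates accumulated in acc
should give a non-decreasing likelihood curve. A minimal sketch, reusing sample
and acc from the script above; the helper log_likelihood is hypothetical:

import math

def log_likelihood(theta, sample, p_at_kt=P_AT_GIVEN_KT, p_at_kf=P_AT_GIVEN_KF):
    # Observed-data log-likelihood: sum_i log P(A = a_i), with K marginalized out.
    ll = 0.0
    for a in sample:
        p_a_kt = p_at_kt if a else 1.0 - p_at_kt
        p_a_kf = p_at_kf if a else 1.0 - p_at_kf
        ll += math.log(theta*p_a_kt + (1 - theta)*p_a_kf)
    return ll

lls = [log_likelihood(t, sample) for t in acc]
# Allow a tiny numerical slack; EM should never decrease the likelihood.
assert all(b >= a - 1e-9 for a, b in zip(lls, lls[1:]))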