Skip to content

Instantly share code, notes, and snippets.

@hkalexling

hkalexling/em.py Secret

Created October 6, 2017 12:10
Show Gist options
  • Save hkalexling/8b97806017cb7cd4ad4937ec1deb157b to your computer and use it in GitHub Desktop.
Save hkalexling/8b97806017cb7cd4ad4937ec1deb157b to your computer and use it in GitHub Desktop.
import numpy as np
import math
# Observed frequency tables: entry i is the number of observations whose value
# is i (init_x expands each table into a per-observation array).
l = [229, 211, 93, 35, 7, 1]    # labelled 'London' by the script entry point
a = [325, 115, 67, 30, 18, 21]  # labelled 'Antwerp' by the script entry point
# Total number of observations. Derived from the data rather than hard-coded so
# the two cannot drift apart; a single module-level n is shared by both
# datasets, so their totals must agree.
n = sum(l)
if sum(a) != n:
    raise ValueError(
        'datasets must have the same number of observations: '
        '{} != {}'.format(sum(a), n))
k = 0  # current number of mixture components (set by init_para)
threshold = 1e-6  # EM convergence tolerance on the L2 change of lamb and pi
threshole = threshold  # original misspelled name, kept because EM reads it
x_l = None  # per-observation values for l, filled in by init_x
x_a = None  # per-observation values for a, filled in by init_x
lamb = None  # Poisson rates, one per component (set by init_para / M)
pi = None  # mixing weights, one per component (set by init_para / M)
def init_x():
    """Expand the frequency tables ``l`` and ``a`` into observation arrays.

    ``l[v]`` (resp. ``a[v]``) is the number of observations equal to ``v``.
    The globals ``x_l`` and ``x_a`` are rebound to integer arrays holding each
    value ``v`` repeated that many times, in ascending order of ``v``.
    """
    global x_l, x_a
    x_l = np.repeat(np.arange(len(l)), l)
    x_a = np.repeat(np.arange(len(a)), a)
def init_para(_k):
    """Reset the EM starting point for a mixture of ``_k`` Poisson components.

    Sets the global ``k`` to ``_k``, ``lamb`` to the float rates
    1.0, 2.0, ..., ``_k`` and ``pi`` to equal weights ``1/_k``.
    """
    global lamb, pi, k
    k = _k
    # Distinct starting rates: with identical rates and weights every component
    # would receive the same responsibilities.
    lamb = np.arange(1, k + 1, dtype=float)
    pi = np.full(k, 1 / k)
def E(x):
    """E-step: posterior probability of each mixture component per observation.

    Reads the module globals ``lamb`` (Poisson rates) and ``pi`` (weights).

    Args:
        x: 1-D sequence of non-negative integer counts.

    Returns:
        Array of shape ``(len(x), len(lamb))`` whose row ``i`` is
        ``pi[j] * Poisson(x[i]; lamb[j])`` normalised to sum to 1 over ``j``.
    """
    x = np.asarray(x, dtype=float)
    col = x[:, None]  # one row per observation, broadcast against components
    # Work in log space so large counts cannot under/overflow into 0/0 = nan.
    # The 1/x! factor of the Poisson pmf is identical for every component in a
    # row, so it cancels in the normalisation and is omitted.
    with np.errstate(divide='ignore', invalid='ignore'):
        # x * log(lamb) with 0 * log(0) taken as 0, matching 0.0 ** 0 == 1.
        x_log_lamb = np.where(col == 0, 0.0, col * np.log(lamb))
        log_v = np.log(pi) - lamb + x_log_lamb
    # Subtract each row's maximum before exponentiating (log-sum-exp trick).
    log_v = log_v - log_v.max(axis=1, keepdims=True)
    v = np.exp(log_v)
    return v / v.sum(axis=1, keepdims=True)
def M(z, x):
    """M-step: re-estimate the globals ``lamb`` and ``pi`` from responsibilities.

    Args:
        z: responsibilities from ``E``, shape ``(len(x), number of components)``.
        x: 1-D sequence of observed counts, one per row of ``z``.

    Sets ``lamb[j]`` to the responsibility-weighted mean of ``x`` and
    ``pi[j]`` to the mean responsibility of component ``j``. A component whose
    total responsibility is 0 gets ``lamb[j] = nan`` (0/0).
    """
    global lamb, pi
    z = np.asarray(z, dtype=float)
    x = np.asarray(x, dtype=float)
    weight = z.sum(axis=0)  # effective number of observations per component
    lamb = (x @ z) / weight
    # Normalise by the number of observations actually passed in, not by a
    # separately maintained global total.
    pi = weight / z.shape[0]
def EM(x, max_iter=None):
    """Run EM on ``x`` from the current ``lamb``/``pi`` until they converge.

    Each iteration runs ``E`` then ``M`` (which rebinds the globals ``lamb``
    and ``pi``). The loop stops once the L2 change of both parameter vectors
    is below ``threshole``, and then prints a summary.

    Args:
        x: per-observation counts.
        max_iter: optional cap on the number of iterations; ``None`` (the
            default) keeps the original unbounded loop.

    Returns:
        The number of iterations performed.

    Raises:
        FloatingPointError: if ``lamb`` or ``pi`` becomes non-finite (e.g. a
            component whose total responsibility is 0 gives 0/0 in ``M``).
            NaN never passes the convergence test, so without this check the
            loop would never terminate.
    """
    iteration = 0
    while max_iter is None or iteration < max_iter:
        z = E(x)
        # Snapshot the old parameters because M replaces the globals. Named
        # so they do not shadow the global data list ``l`` or the pmf ``p``.
        prev_lamb = np.array(lamb)
        prev_pi = np.array(pi)
        M(z, x)
        iteration += 1
        if not (np.all(np.isfinite(lamb)) and np.all(np.isfinite(pi))):
            raise FloatingPointError(
                'EM diverged at iteration {}: lambda={}, pi={}'.format(
                    iteration, lamb, pi))
        delta_lamb = np.linalg.norm(lamb - prev_lamb, 2)
        delta_pi = np.linalg.norm(pi - prev_pi, 2)
        if delta_pi < threshole and delta_lamb < threshole:
            print('-' * 80)
            print('EM ended. k = {}, iteration: {}'.format(k, iteration))
            print('Lambda: {}'.format(lamb))
            print('Pi: {}'.format(pi))
            return iteration
    print('-' * 80)
    print('EM stopped without converging. k = {}, iteration: {}'.format(
        k, iteration))
    return iteration
def p(x):
    """Return the fitted mixture probability of observing the count ``x``.

    Sums ``pi[j] * exp(-lamb[j]) * lamb[j] ** x / x!`` over the first ``k``
    components, reading the module globals ``k``, ``pi`` and ``lamb``.
    Returns ``0`` when ``k`` is 0.
    """
    terms = (
        pi[comp] * (1 / math.factorial(x)) * math.exp(-1 * lamb[comp])
        * lamb[comp] ** x
        for comp in range(k)
    )
    return sum(terms)
def main(name, x, true, max_components=5):
    """Fit Poisson mixtures with 1..max_components components and report them.

    For each component count, restarts from ``init_para``, runs ``EM`` on
    ``x``, and then prints, for every value ``j`` covered by ``true``, the
    fitted probability ``p(j)`` next to the observed relative frequency.

    Args:
        name: label printed above the results.
        x: per-observation counts passed to ``EM``.
        true: frequency table; ``true[j]`` is how many observations equal ``j``.
        max_components: largest number of components to fit (default 5, the
            previously hard-coded value).
    """
    print('=' * 80)
    print(name + '\n')
    # Normalise by this table's own total instead of a separately kept global.
    total = sum(true)
    for i in range(1, max_components + 1):
        init_para(i)
        EM(x)
        print('\n')
        # Iterate over the table itself rather than a hard-coded range(6).
        for j, count in enumerate(true):
            # Label the observed value "x": "k" already denotes the number of
            # components in EM's summary line.
            print('x = {}; predicted: {}; true: {}'.format(
                j, p(j), count / total))
        print('')
if __name__ == '__main__':
    # Build the observation arrays first: x_l / x_a are None until init_x runs.
    init_x()
    for dataset_name, observations, counts in (
        ('London', x_l, l),
        ('Antwerp', x_a, a),
    ):
        main(dataset_name, observations, counts)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment