Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
import pymc3, numpy, sys, seaborn, re
def _count_ratings(fn, num_levels):
    """Tally how many lines of file *fn* start with each rating 1..num_levels.

    Each line is expected to begin with an integer rating. Lines that do not
    start with digits, or whose rating falls outside 1..num_levels, are
    reported on stdout and skipped.

    Returns a list of length *num_levels* where element ``i`` is the number of
    lines whose rating was ``i + 1``.
    """
    counts = [0] * num_levels
    leading_digits = re.compile(r'\d+')  # compiled once, reused for every line
    # `with` guarantees the file handle is closed even if reading fails.
    with open(fn) as handle:
        for line in handle:
            match = leading_digits.match(line)
            rating = int(match.group()) if match else None
            # Explicit range check: a rating of 0 would otherwise index
            # counts[-1] and be silently counted as the top rating.
            if rating is None or not 1 <= rating <= num_levels:
                # Single pre-formatted argument so the output is the same
                # under both Python 2 and Python 3.
                print('%s can not parse: %s' % (fn, line))
                continue
            counts[rating - 1] += 1
    return counts


def get_dist(fn):
    """Sample the posterior of the mean rating for the ratings listed in *fn*.

    The file is read line by line; every line should start with a rating from
    1 to 5. The per-rating counts are printed, then modelled as a multinomial
    draw whose category probabilities have a flat Dirichlet(1, ..., 1) prior.
    The quantity of interest, 'avg', is the expected rating
    ``sum(p[i] * (i + 1))``.

    Parameters:
        fn: path of the text file holding one rating per line.

    Returns:
        The values recorded for 'avg' in the trace produced by
        ``pymc3.sample(4000, pymc3.Slice())`` (presumably one value per draw;
        the exact container type depends on the installed PyMC3 version).
    """
    rating_values = [1, 2, 3, 4, 5]
    y = _count_ratings(fn, len(rating_values))
    print(y)
    k = len(y)
    n = sum(y)
    model = pymc3.Model()
    with model:
        # All-ones concentration vector: uniform prior over the simplex.
        p = pymc3.Dirichlet('probs', a=numpy.array([1.0] * k), shape=k)
        # Creating the observed variable inside the model context is what ties
        # the counts to `p`; the returned handle itself is not needed.
        pymc3.Multinomial('data', n=numpy.array([[n]]), p=p, observed=y)
        avg = pymc3.Deterministic('avg', pymc3.dot(p, rating_values))
        trace = pymc3.sample(4000, pymc3.Slice())
    return trace[avg]
# Overlay the posterior of the mean rating for every file named on the
# command line on one 10x10 figure and save it as a PNG.
seaborn.plt.figure(figsize=(10, 10))
for fn in sys.argv[1:]:
    # Label each curve with its source file so the legend can tell them apart.
    seaborn.distplot(get_dist(fn), label=fn)
seaborn.plt.legend()
seaborn.plt.savefig('ratings_mcmc.png')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment