Skip to content

Instantly share code, notes, and snippets.

@erikbern
Last active February 14, 2018 05:05
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save erikbern/6a79c41384b217ddc097 to your computer and use it in GitHub Desktop.
Save erikbern/6a79c41384b217ddc097 to your computer and use it in GitHub Desktop.
import pymc3, numpy, sys, seaborn, re
def get_dist(fn):
    """Sample the posterior of the mean star rating for one ratings file.

    Each line of ``fn`` is expected to start with an integer rating between
    1 and 5; everything from the first non-digit character onwards is
    ignored.  Lines whose leading token is not a rating in that range are
    reported on stdout and skipped.

    The per-star counts are used as the observed value of a Multinomial whose
    probabilities have a Dirichlet prior with all concentration parameters
    set to 1.0.  The model is sampled with 4000 draws of the Slice step
    method.

    Args:
        fn: Path of the text file to read, one rating per line.

    Returns:
        ``trace[avg]`` -- the values recorded in the trace for the
        deterministic ``avg`` variable (dot product of the star probabilities
        with the star values 1..5).
    """
    y = [0, 0, 0, 0, 0]
    k = len(y)

    # ``with`` closes the file even if reading fails part-way through.
    with open(fn) as ratings_file:
        for line in ratings_file:
            try:
                rating = int(re.split(r'\D', line)[0])
                # Reject out-of-range values explicitly: a rating of 0 would
                # otherwise index y[-1] and be counted as a k-star rating.
                if not 1 <= rating <= k:
                    raise ValueError('rating out of range: %d' % rating)
            except ValueError:
                # Single formatted string so the output is the same under
                # Python 2 and Python 3.
                print('%s can not parse: %s' % (fn, line))
            else:
                y[rating - 1] += 1
    print(y)

    n = sum(y)
    # Star values matching the buckets of ``y``: 1, 2, ..., k.
    star_values = list(range(1, k + 1))

    model = pymc3.Model()
    with model:
        p = pymc3.Dirichlet('probs', a=numpy.array([1.0] * k), shape=k)
        # The observed node only has to be declared inside the model context;
        # it is not referenced again, so it is not bound to a name.
        # NOTE(review): n is passed as a 1x1 array -- confirm this is the
        # shape the installed pymc3 Multinomial expects.
        pymc3.Multinomial('data', n=numpy.array([[n]]), p=p, observed=y)
        avg = pymc3.Deterministic('avg', pymc3.dot(p, star_values))
        trace = pymc3.sample(4000, pymc3.Slice())
    return trace[avg]
# Script body: overlay one distribution plot per input file on a single
# figure and write the result to 'ratings_mcmc.png'.
# NOTE(review): the plotting module is reached through ``seaborn.plt`` rather
# than an explicit matplotlib import -- confirm the installed seaborn version
# still exposes it.
plt = seaborn.plt
input_files = sys.argv[1:]  # every command-line argument is a ratings file

plt.figure(figsize=(10, 10))

for fn in input_files:
    # One labelled curve per file, so the legend can tell them apart.
    avg = get_dist(fn)
    seaborn.distplot(avg, label=fn)

plt.legend()
plt.savefig('ratings_mcmc.png')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment