Skip to content

Instantly share code, notes, and snippets.

@j2kun
Last active February 5, 2022 17:11
Show Gist options
  • Save j2kun/e0387f56d3749b3f0116cf5683dfd735 to your computer and use it in GitHub Desktop.
Save j2kun/e0387f56d3749b3f0116cf5683dfd735 to your computer and use it in GitHub Desktop.
from dataclasses import dataclass
import numpy as np
import pymc3 as pm
@dataclass
class SkillDistribution:
'''A class representing the data of a skill distribution.
Used to represent both the prior and posterior distribution of individual
performance, in both cases a Normal distribution.
'''
mean: float = 1500
stddev: float = 10
def define_model(p1_prior, p2_prior, perf_stddev):
# Starting with Elo for simplicity, to get something working
model = pm.Model()
with model:
p1_skill = pm.Normal("p1_skill", mu=p1_prior.mean, sigma=p1_prior.stddev)
p2_skill = pm.Normal("p2_skill", mu=p2_prior.mean, sigma=p2_prior.stddev)
p1_perf = pm.Normal("p1_perf", mu=p1_skill, sigma=perf_stddev)
p2_perf = pm.Normal("p2_perf", mu=p2_skill, sigma=perf_stddev)
perf_diff = p1_perf - p2_perf
pm.Bernoulli(
"p1_win", logit_p=perf_diff, observed=np.ones(1)
)
return model
if __name__ == "__main__":
p1_prior = SkillDistribution(mean=1200, stddev=50)
p2_prior = SkillDistribution(mean=2000, stddev=50)
print(f"P1: {p1_prior.mean}, P2: {p2_prior.mean}")
for i in range(10):
print(f"-------------\nRound {i}\n-------------")
model = define_model(p1_prior, p2_prior, 10)
outcomes = pm.find_MAP(model=model)
# posterior is new prior
p1_prior = SkillDistribution(mean=outcomes['p1_skill'], stddev=50)
p2_prior = SkillDistribution(mean=outcomes['p2_skill'], stddev=50)
print(f"P1: {p1_prior.mean}, P2: {p2_prior.mean}")
P1: 1200, P2: 2000
-------------
Round 0
-------------
|███████████████████████████████████████████████████████████████████████████████████████████| 100.00% [21/21 00:00<00:00 logp = -78.073, ||grad|| = 3.0609e-06]
P1: 1585.4337436308458, P2: 1614.5662563691542
-------------
Round 1
-------------
|███████████████████████████████████████████████████████████████████████████████████████████| 100.00% [12/12 00:00<00:00 logp = -16.224, ||grad|| = 0.00014347]
P1: 1601.8522709827512, P2: 1598.1477290172488
-------------
Round 2
-------------
|███████████████████████████████████████████████████████████████████████████████████████████████| 100.00% [9/9 00:00<00:00 logp = -16.129, ||grad|| = 0.03397]]
P1: 1603.566491445544, P2: 1596.433508554456
-------------
Round 3
-------------
|█████████████████████████████████████████████████████████████████████████████████████████████| 100.00% [8/8 00:00<00:00 logp = -16.106, ||grad|| = 0.0011281]]
P1: 1604.1523779209522, P2: 1595.8476220790478
-------------
Round 4
-------------
|█████████████████████████████████████████████████████████████████████████████████████████████| 100.00% [7/7 00:00<00:00 logp = -16.105, ||grad|| = 0.0003497]]
P1: 1604.4674231731594, P2: 1595.5325768268406
-------------
Round 5
-------------
|████████████████████████████████████████████████████████████████████████████████████████████| 100.00% [6/6 00:00<00:00 logp = -16.105, ||grad|| = 0.00018625]]
P1: 1604.6671121661452, P2: 1595.3328878338548
-------------
Round 6
-------------
|████████████████████████████████████████████████████████████████████████████████████████████| 100.00% [6/6 00:00<00:00 logp = -16.105, ||grad|| = 0.00012493]]
P1: 1604.820195609026, P2: 1595.179804390974
-------------
Round 7
-------------
|████████████████████████████████████████████████████████████████████████████████████████████| 100.00% [6/6 00:00<00:00 logp = -16.105, ||grad|| = 9.1985e-05]]
P1: 1604.9425464641267, P2: 1595.0574535358733
-------------
Round 8
-------------
|█████████████████████████████████████████████████████████████████████████████████████████████| 100.00% [6/6 00:00<00:00 logp = -16.105, ||grad|| = 7.202e-05]]
P1: 1605.0436597872829, P2: 1594.9563402127171
-------------
Round 9
-------------
|████████████████████████████████████████████████████████████████████████████████████████████| 100.00% [6/6 00:00<00:00 logp = -16.105, ||grad|| = 5.8834e-05]]
P1: 1605.129436618363, P2: 1594.870563381637
pymc3==3.11.4
numpy==1.20.2
scipy==1.7.3
Theano-PyMC==1.1.2
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment