@sinisterra
Created August 8, 2019 21:48

Causality with Bayesian networks
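
```python
# Assumes the pre-1.0 pomegranate API (e.g. 0.14.x); pomegranate >= 1.0 was
# rewritten on PyTorch and no longer provides these classes.
from pomegranate import (
    BayesianNetwork,
    ConditionalProbabilityTable,
    DiscreteDistribution,
    State,
)

# Prior over seasons: uniform.
season = DiscreteDistribution(
    {"spring": 1.0 / 4, "summer": 1.0 / 4, "autumn": 1.0 / 4, "winter": 1.0 / 4}
)

# season -> rain: each row is [season, rain, P(rain | season)].
rain = ConditionalProbabilityTable(
    [
        ["spring", 1, 0.8],
        ["spring", 0, 0.2],
        ["summer", 1, 0.5],
        ["summer", 0, 0.5],
        ["autumn", 1, 0.6],
        ["autumn", 0, 0.4],
        ["winter", 1, 0.1],
        ["winter", 0, 0.9],
    ],
    [season],
)

# rain -> sprinkler: each row is [rain, sprinkler, P(sprinkler | rain)].
sprinkler = ConditionalProbabilityTable(
    [[0, 1, 0.4], [0, 0, 0.6], [1, 1, 0.01], [1, 0, 0.99]], [rain]
)

# (sprinkler, rain) -> grass_wet: each row is
# [sprinkler, rain, grass_wet, P(grass_wet | sprinkler, rain)].
grass_wet = ConditionalProbabilityTable(
    [
        [0, 0, 1, 0],
        [0, 0, 0, 1],
        [0, 1, 1, 0.8],
        [0, 1, 0, 0.2],
        [1, 0, 1, 0.9],
        [1, 0, 0, 0.1],
        [1, 1, 1, 0.99],
        [1, 1, 0, 0.01],
    ],
    [sprinkler, rain],
)

# Wrap each distribution in a named node of the network.
s0 = State(season, name="season")
s1 = State(rain, name="rain")
s2 = State(sprinkler, name="sprinkler")
s3 = State(grass_wet, name="grass_wet")

# The edges mirror the parent lists passed to each CPT.
model = BayesianNetwork("Rain")
model.add_states(s0, s1, s2, s3)
model.add_edge(s0, s1)  # season -> rain
model.add_edge(s1, s2)  # rain -> sprinkler
model.add_edge(s2, s3)  # sprinkler -> grass_wet
model.add_edge(s1, s3)  # rain -> grass_wet
model.bake()

# Posterior marginals given the evidence. Observed nodes come back as their
# observed values; unobserved nodes come back as DiscreteDistribution objects.
p = model.predict_proba({"season": "winter", "rain": 1})
print(list(zip(["season", "rain", "sprinkler", "grass_wet"], p)))
```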
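Read together, the CPT parent lists and the edges encode the factorization below. Writing $S$, $R$, $K$, $G$ for season, rain, sprinkler and grass_wet (shorthand introduced here, not used by pomegranate), the `predict_proba` query can be worked by hand from the CPT rows. Once rain is observed, season has no other path to sprinkler or grass_wet, so the season evidence drops out:

$$
P(S, R, K, G) = P(S)\,P(R \mid S)\,P(K \mid R)\,P(G \mid K, R)
$$

$$
\begin{aligned}
P(K = 1 \mid S = \text{winter}, R = 1) &= P(K = 1 \mid R = 1) = 0.01 \\
P(G = 1 \mid S = \text{winter}, R = 1) &= \sum_{k \in \{0, 1\}} P(K = k \mid R = 1)\,P(G = 1 \mid K = k, R = 1) \\
&= 0.99 \times 0.8 + 0.01 \times 0.99 = 0.8019
\end{aligned}
$$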
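To sanity-check the output of `predict_proba`, the same posteriors can be computed by brute-force enumeration over all 4 × 2 × 2 × 2 = 32 assignments. This is a minimal sketch in plain Python that does not use pomegranate. The dictionaries and the helpers `bernoulli`, `joint` and `posterior` are illustrative names, not part of any library, and exact enumeration is only practical for tiny networks like this one.

```python
from itertools import product

SEASONS = ["spring", "summer", "autumn", "winter"]

# CPT numbers copied from the pomegranate model; the names are illustrative.
P_SEASON = {s: 0.25 for s in SEASONS}  # P(season), uniform prior
P_RAIN = {"spring": 0.8, "summer": 0.5, "autumn": 0.6, "winter": 0.1}  # P(rain=1 | season)
P_SPRINKLER = {0: 0.4, 1: 0.01}  # P(sprinkler=1 | rain)
P_WET = {(0, 0): 0.0, (0, 1): 0.8, (1, 0): 0.9, (1, 1): 0.99}  # P(grass_wet=1 | sprinkler, rain)


def bernoulli(p_one, value):
    """Probability that a binary variable with P(1) = p_one takes `value`."""
    return p_one if value == 1 else 1.0 - p_one


def joint(season, rain, sprinkler, wet):
    """Chain-rule factorization implied by the network's edges."""
    return (
        P_SEASON[season]
        * bernoulli(P_RAIN[season], rain)
        * bernoulli(P_SPRINKLER[rain], sprinkler)
        * bernoulli(P_WET[(sprinkler, rain)], wet)
    )


def posterior(target, evidence):
    """P(target = 1 | evidence), summing the joint over consistent assignments."""
    names = ("season", "rain", "sprinkler", "grass_wet")
    numerator = denominator = 0.0
    for values in product(SEASONS, (0, 1), (0, 1), (0, 1)):
        assignment = dict(zip(names, values))
        if any(assignment[k] != v for k, v in evidence.items()):
            continue  # skip assignments that contradict the evidence
        p = joint(*values)
        denominator += p
        if assignment[target] == 1:
            numerator += p
    return numerator / denominator


evidence = {"season": "winter", "rain": 1}
print(round(posterior("sprinkler", evidence), 4))  # -> 0.01
print(round(posterior("grass_wet", evidence), 4))  # -> 0.8019
```

The same function also answers diagnostic queries that run against the arrows, such as `posterior("rain", {"grass_wet": 1})`, where observing wet grass raises the belief in rain above its prior of 0.5.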