Last active
April 13, 2017 20:44
-
-
Save kghose/fb997a32150bbea6573256a2ebb6a2a4 to your computer and use it in GitHub Desktop.
Simple simulator to explore conditional independence in Bayesian networks
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Simple simulator to explore conditional independence in Bayesian networks | |
Node representation: | |
Say the node is has two parents named L and K and the probability table is | |
L | K | p(T) | |
-------------- | |
F | F | 0.2 | |
F | T | 0.4 | |
T | F | 0.7 | |
T | T | 0.9 | |
The node is represented as the following dictionary | |
{ | |
'parents': ['L', 'K'], | |
'table': { | |
'FF': 0.2, | |
'FT': 0.4, | |
'TF': 0.7, | |
'TT': 0.9 | |
} | |
} | |
For a node with no parents we have | |
{ | |
'parents': [], | |
'table': { | |
'X': 0.2, | |
} | |
} | |
A network is represented as a dictionary with keys being the node names and values being the | |
nodes represented as the dictionaries described above | |
So a simple two element network A -> B with probability tables | |
A p(T) = 0.4 | |
A | B p(T) | |
------------ | |
T | 0.2 | |
F | 0.6 | |
Is represented as | |
{ | |
'A': { | |
'parents': [], | |
'table': { | |
'X': 0.4 | |
} | |
}, | |
'B': { | |
'parents': ['A'], | |
'table': { | |
'T': 0.2, | |
'F': 0.6 | |
} | |
} | |
} | |
""" | |
import unittest | |
from collections import OrderedDict | |
import numpy as np | |
from scipy.stats import fisher_exact | |
def validate_network(net):
    """Run simple sanity checks on a network.

    Checks, in this order:
      1. ``net`` is an ``OrderedDict``;
      2. every parent named by a node is itself a node of the network;
      3. the parent graph contains no cycle (delegated to ``cycle_check``);
      4. every probability in every table lies in [0, 1];
      5. every table has an entry for each combination of parent states
         ('X' for a parentless node; 'FF', 'FT', 'TF', 'TT' for two parents, ...),
         i.e. every key ``get_p_true`` may look up during simulation.

    :param net: network mapping node name -> {'parents': [...], 'table': {...}}
    :raises AssertionError: on the first failed check. The checks raise explicitly
        instead of using ``assert`` statements so they still run under ``python -O``.
    """
    from itertools import product

    if not isinstance(net, OrderedDict):
        raise AssertionError('Network must be OrderedDict')

    for k, v in net.items():
        for p in v['parents']:
            if p not in net:
                raise AssertionError('Node {} has parent {} not in network'.format(k, p))

    for k in net:
        cycle_check(k, k, net)

    for k, v in net.items():
        for p in v['table'].values():
            # Written as "not (...)" so NaN probabilities are rejected too.
            if not 0 <= p <= 1.0:
                raise AssertionError('Node {} has out of range probability in table'.format(k))

    for k, v in net.items():
        n_parents = len(v['parents'])
        if n_parents == 0:
            required = ['X']
        else:
            required = [''.join(states) for states in product('TF', repeat=n_parents)]
        missing = [key for key in required if key not in v['table']]
        if missing:
            # Without this check a missing entry only surfaces as a KeyError mid-simulation.
            raise AssertionError(
                'Node {} has no table entry for parent state(s) {}'.format(k, missing))
def cycle_check(descendant, ancestor, network):
    """Raise if ``descendant`` is reachable by walking parent links upward from ``ancestor``.

    Called as ``cycle_check(k, k, net)`` this detects whether node ``k`` is its own
    ancestor, i.e. whether ``k`` lies on a cycle.

    The walk is an iterative depth-first search with a visited set. Each ancestor is
    expanded at most once, so:
      * a cycle that does not pass through ``descendant`` terminates the search
        instead of recursing forever (it is reported when that cycle's own nodes
        are checked);
      * DAGs with shared ancestors are not re-walked exponentially.

    :param descendant: node whose reappearance among the ancestors signals a cycle
    :param ancestor: node from which to start walking parent links
    :param network: network mapping node name -> {'parents': [...], ...}
    :raises AssertionError: if ``descendant`` is found among the ancestors of ``ancestor``
    :raises KeyError: if a parent name is not a node of ``network``
    """
    visited = set()
    to_visit = [ancestor]
    while to_visit:
        node = to_visit.pop()
        for p in network[node]['parents']:
            if p == descendant:
                raise AssertionError('Cycle in given network')
            if p not in visited:
                visited.add(p)
                to_visit.append(p)
def get_p_true(node, inputs=None):
    """Look up the probability that ``node`` is True given its parents' states.

    :param node: node dict with 'parents' and 'table' keys
    :param inputs: mapping parent name -> 'T' or 'F', e.g. {'L': 'T', 'K': 'F'};
        ignored (and may be None) when the node has no parents
    :return: the table entry keyed by the parents' states concatenated in
        'parents' order, or the 'X' entry for a parentless node
    """
    parents = node['parents']
    if not parents:
        key = 'X'
    else:
        key = ''.join([inputs[name] for name in parents])
    return node['table'][key]
def get_node_value(node_name, network, network_state):
    """Return the sampled state ('T' or 'F') of ``node_name``.

    If the node already has a state in ``network_state`` that state is returned
    unchanged. Otherwise each parent is resolved first (recursively, in 'parents'
    order), then one uniform draw from ``np.random.rand()`` is compared against
    the node's p(True). The result is stored in ``network_state``.

    :param node_name: name of the node to evaluate
    :param network: network mapping node name -> node dict
    :param network_state: dict of node name -> 'T'/'F'; mutated in place and acts
        as a memo so each node is sampled at most once per joint sample
    :return: 'T' or 'F'
    """
    if node_name in network_state:
        return network_state[node_name]

    node = network[node_name]
    parent_states = {}
    for parent in node['parents']:
        parent_states[parent] = get_node_value(parent, network, network_state)

    p_true = get_p_true(node, inputs=parent_states)
    draw = np.random.rand()
    network_state[node_name] = 'T' if draw < p_true else 'F'
    return network_state[node_name]
def run_network(network, start_nodes):
    """Draw one joint sample of the network.

    :param network: network mapping node name -> node dict
    :param start_nodes: nodes to evaluate; evaluating them pulls in their ancestors,
        so they should be chosen so that every node ends up evaluated
    :return: dict of node name -> 'T'/'F' for every node that was evaluated
    """
    state = dict()
    for name in start_nodes:
        # Called for its side effect of filling `state`; the returned value is not needed.
        get_node_value(name, network, state)
    return state
def run_simulations(network, start_nodes, runs=1000):
    """Sample the network ``runs`` times.

    :param network: network mapping node name -> node dict
    :param start_nodes: passed to ``run_network`` for every sample
    :param runs: number of independent joint samples
    :return: numpy structured array of length ``runs`` with one int field per node
        (in the network's key order); 'F' is encoded as -1 and 'T' as +1
    """
    names = list(network.keys())

    def encode(state):
        # Fixed field order: the network's key order.
        return tuple(-1 if state[name] == 'F' else 1 for name in names)

    rows = [encode(run_network(network, start_nodes)) for _ in range(runs)]
    record_dtype = [(name, int) for name in names]
    return np.array(rows, dtype=record_dtype)
def compute_2x2(sim, na, nb, no=None):
    """Build the 2x2 co-occurrence count table of nodes ``na`` and ``nb``.

    :param sim: structured array from ``run_simulations`` (fields hold +1/-1)
    :param na: field name for the rows (row 0: na == +1, row 1: na == -1)
    :param nb: field name for the columns (col 0: nb == +1, col 1: nb == -1)
    :param no: optional (field name, value) pair; when given, only samples whose
        field equals that value are counted
    :return: 2x2 integer numpy array of counts
    """
    if no is None:
        mask = np.ones(len(sim[na]), dtype=bool)
    else:
        mask = sim[no[0]] == no[1]

    a_sel = sim[na][mask]
    b_sel = sim[nb][mask]
    levels = (1, -1)
    counts = []
    for a_level in levels:
        counts.append([np.sum((a_sel == a_level) & (b_sel == b_level)) for b_level in levels])
    return np.array(counts)
def chi_square_statistic(x):
    """Compute Pearson's chi-square statistic for a contingency table.

    Written for the 2x2 tables produced by ``compute_2x2``, but any r x c table
    is accepted. The original hard-coded the four 2x2 cells and failed on other
    shapes. Expected counts are row_total * col_total / grand_total. They are
    computed in the same operation order as before, so 2x2 results are unchanged.
    A zero row or column total makes an expected count zero, which yields nan.

    :param x: 2-D array-like of counts
    :return: sum over cells of (observed - expected)**2 / expected
    """
    x = np.asarray(x)
    total = float(x.sum())
    row_totals = x.sum(axis=1)
    col_totals = x.sum(axis=0)
    expected_x = np.outer(row_totals, col_totals / total)
    return ((x - expected_x) ** 2 / expected_x).sum()
def phi(x):
    """Compute the phi coefficient of a 2x2 contingency table.

    phi = (x00*x11 - x01*x10) / sqrt(r0 * r1 * c0 * c1), where r* and c* are row
    and column totals.

    The table is converted to float64 first. With int64 counts the product of the
    four marginals silently overflows once counts reach roughly 110k per table
    (e.g. 60000**4 > 2**63), producing nan or a meaningless value.

    :param x: 2x2 array-like of counts
    :return: phi in [-1, 1]; nan/inf (with a numpy warning) if a marginal is zero
    """
    x = np.asarray(x, dtype=float)
    row_totals = x.sum(axis=1)
    col_totals = x.sum(axis=0)
    numerator = x[0, 0] * x[1, 1] - x[0, 1] * x[1, 0]
    return numerator / np.sqrt(row_totals[0] * row_totals[1] * col_totals[0] * col_totals[1])
def probe_nodes(sim, na, nb, no=None):
    """Return the 2x2 count tables for nodes ``na``/``nb`` and their chi-square statistics.

    Without an observed node, one table (over all samples) is produced. With an
    observed node ``no``, two tables are produced: first for samples where
    ``no`` is -1 ('F'), then for samples where it is +1 ('T').

    :param sim: structured array from ``run_simulations``
    :param na: first node name
    :param nb: second node name
    :param no: optional name of the observed (conditioning) node
    :return: (list of 2x2 count tables, list of chi-square statistics in the same order)
    """
    if no is None:
        conditions = [None]
    else:
        conditions = [(no, -1), (no, 1)]

    tables = []
    for condition in conditions:
        tables.append(compute_2x2(sim, na, nb, condition))
    statistics = [chi_square_statistic(table) for table in tables]
    return tables, statistics
def pretty_print_probe_nodes(sim, na, nb, nobs=None, method='phi'):
    """Print an association statistic between nodes ``na`` and ``nb``.

    Without an observed node, one statistic (over all samples) is printed. With
    ``nobs`` given, two are printed: first for samples where ``nobs`` is -1 ('F'),
    then for samples where it is +1 ('T').

    :param sim: structured array from ``run_simulations``
    :param na: first node name
    :param nb: second node name
    :param nobs: optional name of the observed (conditioning) node
    :param method: 'chi' (Pearson chi-square statistic), 'phi' (phi coefficient)
        or 'fisher' (p-value from ``scipy.stats.fisher_exact``)
    :raises ValueError: if ``method`` is not one of the above. Previously an unknown
        method left ``stat`` unbound and crashed with UnboundLocalError after the
        tables had already been computed.
    """
    if method not in ('chi', 'phi', 'fisher'):
        raise ValueError(
            "method must be one of 'chi', 'phi', 'fisher'; got {!r}".format(method))

    conditions = [(nobs, -1), (nobs, 1)] if nobs is not None else [None]
    tables = [compute_2x2(sim, na, nb, condition) for condition in conditions]

    if method == 'chi':
        stat = [chi_square_statistic(table) for table in tables]
    elif method == 'phi':
        stat = [phi(table) for table in tables]
    else:
        # fisher_exact returns (statistic, p-value); report the p-value.
        stat = [fisher_exact(table)[1] for table in tables]

    print('{} & {} {}: {}'.format(
        na, nb, 'with {} observed'.format(nobs) if nobs is not None else '',
        stat
    ))
class TestNetwork(unittest.TestCase):
    def test_network_validation(self):
        """Network validation"""
        # Each of these must be rejected: a plain dict (not an OrderedDict), a
        # two-node cycle, a parent ('C') that is not in the network, and a
        # probability outside [0, 1].
        bad_networks = (
            {'A': {'parents': ['B']}, 'B': {'parents': ['A']}},
            OrderedDict({'A': {'parents': ['B']}, 'B': {'parents': ['A']}}),
            OrderedDict({'A': {'parents': ['B']}, 'B': {'parents': ['C']}}),
            OrderedDict({'A': {'parents': ['B'], 'table': {'X': 1.1}},
                         'B': {'parents': [], 'table': {'X': 0.2}}}),
        )
        for bad in bad_networks:
            with self.assertRaises(AssertionError):
                validate_network(bad)

        # A well-formed network must pass without raising.
        good_network, _ = party_network()
        validate_network(good_network)
def three_node_head_to_tail_unrelated():
    """Chain A -> C -> B in which B ignores C.

    B's table gives p(B=T) = 0.5 for either state of C, so B is independent of C and A.

    :return: (network as an OrderedDict in order A, C, B; start nodes ['B'])
    """
    network = OrderedDict()
    network['A'] = {'parents': [], 'table': {'X': 0.5}}
    network['C'] = {'parents': ['A'], 'table': {'T': 0.7, 'F': 0.3}}
    network['B'] = {'parents': ['C'], 'table': {'T': 0.5, 'F': 0.5}}
    return network, ['B']
def three_node_head_to_tail_related():
    """Chain A -> C -> B in which each link carries information.

    Both C and B are True with probability 0.7 when their parent is True and 0.3
    when it is False.

    :return: (network as an OrderedDict in order A, C, B; start nodes ['B'])
    """
    network = OrderedDict()
    network['A'] = {'parents': [], 'table': {'X': 0.5}}
    network['C'] = {'parents': ['A'], 'table': {'T': 0.7, 'F': 0.3}}
    network['B'] = {'parents': ['C'], 'table': {'T': 0.7, 'F': 0.3}}
    return network, ['B']
def three_node_tail_to_tail():
    """Common cause: A <- C -> B.

    A and B each copy C with probability 0.7 (p(True) = 0.7 if C is True, else 0.3).

    :return: (network as an OrderedDict in order A, C, B; start nodes ['A', 'B'])
    """
    network = OrderedDict()
    network['A'] = {'parents': ['C'], 'table': {'T': 0.7, 'F': 0.3}}
    network['C'] = {'parents': [], 'table': {'X': 0.5}}
    network['B'] = {'parents': ['C'], 'table': {'T': 0.7, 'F': 0.3}}
    return network, ['A', 'B']
def three_node_head_to_head():
    """Common effect (collider): A -> C <- B.

    A and B are independent fair coins; C's p(True) depends on both (keys are A's
    state followed by B's).

    :return: (network as an OrderedDict in order A, C, B; start nodes ['C'])
    """
    network = OrderedDict()
    network['A'] = {'parents': [], 'table': {'X': 0.5}}
    network['C'] = {
        'parents': ['A', 'B'],
        'table': {'FF': 0.1, 'FT': 0.7, 'TF': 0.3, 'TT': 0.9},
    }
    network['B'] = {'parents': [], 'table': {'X': 0.5}}
    return network, ['C']
def run_three_node_examples(runs=10000):
    """Simulate each three-node network and print A/B association, unconditioned and given C.

    :param runs: number of joint samples drawn per network
    """
    examples = (
        (three_node_head_to_tail_unrelated, 'A -> C -> B Unrelated'),
        (three_node_head_to_tail_related, 'A -> C -> B Related'),
        (three_node_tail_to_tail, 'A <- C -> B'),
        (three_node_head_to_head, 'A -> C <- B'),
    )
    for build, title in examples:
        network, start = build()
        sim = run_simulations(network, start, runs=runs)
        print(title)
        # First without any observed node, then with C observed.
        for observed in (None, 'C'):
            pretty_print_probe_nodes(sim, 'A', 'B', observed)
def party_network():
    """Build the 7-node kid's party network.

    Edges: P -> K, P -> D, E -> F, K -> F, F -> G, D -> G, G -> C.
    As written, C's table gives p(C=T) = 0.4 whatever the state of G.

    :return: (network as an OrderedDict in order E, K, P, F, D, G, C; start nodes ['C'])
    """
    network = OrderedDict()
    network['E'] = {'parents': [], 'table': {'X': 0.1}}
    network['K'] = {'parents': ['P'], 'table': {'T': 0.7, 'F': 0.1}}
    network['P'] = {'parents': [], 'table': {'X': 0.1}}
    network['F'] = {
        'parents': ['E', 'K'],
        'table': {'FF': 0.1, 'FT': 0.5, 'TF': 0.6, 'TT': 0.8},
    }
    network['D'] = {'parents': ['P'], 'table': {'T': 0.8, 'F': 0.4}}
    network['G'] = {
        'parents': ['F', 'D'],
        'table': {'FF': 0.1, 'FT': 0.4, 'TF': 0.6, 'TT': 0.8},
    }
    network['C'] = {'parents': ['G'], 'table': {'T': 0.4, 'F': 0.4}}
    # Evaluating C alone pulls in every other node through its ancestors.
    return network, ['C']
def run_party_example(runs=100000):
    """Simulate the party network and print a fixed series of pairwise probes.

    :param runs: number of joint samples to draw
    """
    print('Party network')
    network, start = party_network()
    sim = run_simulations(network, start, runs=runs)

    # (node a, node b, observed node or None), printed in this order.
    # The unconditioned F/D probe appears twice, presumably so it sits next to
    # both the G- and C-conditioned versions.
    probes = [
        ('K', 'D', None), ('K', 'D', 'P'),
        ('F', 'D', None), ('F', 'D', 'G'),
        ('F', 'D', None), ('F', 'D', 'C'),
        ('P', 'G', None), ('P', 'G', 'D'),
        ('E', 'G', None), ('E', 'G', 'F'),
    ]
    for node_a, node_b, observed in probes:
        pretty_print_probe_nodes(sim, node_a, node_b, observed)
if __name__ == '__main__':
    # Runs the demonstrations. To run the TestNetwork suite instead, call unittest.main().
    for demo in (run_three_node_examples, run_party_example):
        demo(runs=100000)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment