Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
Simple simulator to explore conditional independence in Bayesian networks
"""Simple simulator to explore conditional independence in Bayesian networks

Every node is binary: its state is either True ('T') or False ('F').

Node representation:
Say the node has two parents named L and K and the probability table is

    L | K | p(T)
    --------------
    F | F | 0.2
    F | T | 0.4
    T | F | 0.7
    T | T | 0.9

The node is represented as the following dictionary. Each table key is formed by
concatenating the parents' states ('T' or 'F') in the order in which the parents
are listed under 'parents':

    {
        'parents': ['L', 'K'],
        'table': {
            'FF': 0.2,
            'FT': 0.4,
            'TF': 0.7,
            'TT': 0.9
        }
    }

For a node with no parents the table has the single key 'X':

    {
        'parents': [],
        'table': {
            'X': 0.2,
        }
    }

A network is an OrderedDict (validate_network rejects anything else) whose keys
are the node names and whose values are the node dictionaries described above.

So a simple two-node network A -> B with probability tables

    p(A=T) = 0.4

    A | p(B=T)
    ----------
    T | 0.2
    F | 0.6

is represented as

    {
        'A': {
            'parents': [],
            'table': {
                'X': 0.4
            }
        },
        'B': {
            'parents': ['A'],
            'table': {
                'T': 0.2,
                'F': 0.6
            }
        }
    }
"""
import unittest
from collections import OrderedDict
import numpy as np
from scipy.stats import fisher_exact
def validate_network(net):
    """Run simple sanity checks on a network, raising AssertionError on failure.

    The checks run in this order:

    * ``net`` is an OrderedDict;
    * every parent that a node lists is itself a node of ``net``;
    * following parent links never leads back to the starting node (no cycles),
      checked with cycle_check;
    * every probability in every table lies in [0, 1];
    * every table has an entry for each combination of its parents' states.
      The keys are formed the way get_p_true builds them: the parents' 'T'/'F'
      states concatenated in 'parents' order, or 'X' for a node with no parents.

    Failures are raised explicitly rather than through ``assert`` statements,
    so the checks still run under ``python -O``.

    :param net: OrderedDict mapping node name -> node dict (see module docstring)
    :raises AssertionError: if any check fails
    """
    from itertools import product

    if not isinstance(net, OrderedDict):
        raise AssertionError('Network must be OrderedDict')
    for k, v in net.items():
        for p in v['parents']:
            if p not in net:
                raise AssertionError('Node {} has parent {} not in network'.format(k, p))
    for k in net:
        cycle_check(k, k, net)
    for k, v in net.items():
        for p in v['table'].values():
            # Written as "not (in range)" so a nan probability is also rejected.
            if not 0 <= p <= 1.0:
                raise AssertionError('Node {} has out of range probability in table'.format(k))
    for k, v in net.items():
        n_parents = len(v['parents'])
        if n_parents:
            expected = {''.join(states) for states in product('TF', repeat=n_parents)}
        else:
            expected = {'X'}
        # A missing key would otherwise surface as a KeyError deep inside a simulation.
        missing = expected - set(v['table'])
        if missing:
            raise AssertionError('Node {} has no table entry for {}'.format(k, sorted(missing)))
def cycle_check(descendant, ancestor, network):
    """Raise AssertionError if ``descendant`` can be reached from ``ancestor`` via parent links.

    The walk follows the 'parents' links upward from ``ancestor``. Calling
    ``cycle_check(k, k, network)`` therefore detects any cycle through node k.

    Each node is expanded at most once. This keeps the walk finite when the
    network contains a cycle that does not pass through ``descendant``; a plain
    recursive walk would loop until RecursionError. It also avoids re-walking
    ancestors that are shared by several paths.

    :param descendant: name of the node whose reappearance signals a cycle
    :param ancestor: name of the node whose ancestors are searched
    :param network: mapping node name -> node dict with a 'parents' list
    :raises AssertionError: if ``descendant`` is reachable through parent links
    :raises KeyError: if a parent name is not a key of ``network``
    """
    visited = set()
    pending = list(network[ancestor]['parents'])
    while pending:
        p = pending.pop()
        if p == descendant:
            raise AssertionError('Cycle in given network')
        if p in visited:
            continue
        visited.add(p)
        pending.extend(network[p]['parents'])
def get_p_true(node, inputs=None):
    """Return the probability that ``node`` takes the True state.

    :param node: node dict with 'parents' and 'table' (see module docstring)
    :param inputs: mapping parent name -> 'T' or 'F'. For parents L and K it looks
        like ``{'L': 'T', 'K': 'F'}``. It is not consulted for a node without parents.
    :return: the table entry selected by the parents' states, or the 'X' entry
        for a node without parents
    """
    parent_names = node['parents']
    if not parent_names:
        return node['table']['X']
    # Key = parent states concatenated in the order the parents are listed.
    lookup = ''.join([inputs[name] for name in parent_names])
    return node['table'][lookup]
def get_node_value(node_name, network, network_state):
    """Return the state ('T' or 'F') of ``node_name``, sampling it on first request.

    If the node already has a state in ``network_state``, that state is returned.
    Otherwise each parent is evaluated first, recursively. The node is then set
    to 'T' when a ``np.random.rand()`` draw falls below the probability given by
    get_p_true for those parent states, and to 'F' otherwise.

    :param node_name: name of the node to evaluate
    :param network: network dict (see module docstring)
    :param network_state: dict mapping node name -> 'T'/'F'. It is updated in
        place with every node evaluated during this call.
    :return: the state of ``node_name``
    """
    if node_name in network_state:
        return network_state[node_name]
    node = network[node_name]
    parent_states = {}
    for parent in node['parents']:
        parent_states[parent] = get_node_value(parent, network, network_state)
    p_true = get_p_true(node, inputs=parent_states)
    draw = np.random.rand()
    network_state[node_name] = 'T' if draw < p_true else 'F'
    return network_state[node_name]
def run_network(network, start_nodes):
    """Draw one joint sample of the network.

    Each start node is evaluated with get_node_value, which evaluates that
    node's ancestors first. Only nodes that are start nodes, or ancestors of
    one, appear in the result.

    :param network: network dict (see module docstring)
    :param start_nodes: names whose evaluation should cover every node of interest
    :return: dict mapping node name -> 'T' or 'F'
    """
    sampled = dict()
    for name in start_nodes:
        get_node_value(name, network, sampled)
    return sampled
def run_simulations(network, start_nodes, runs=1000):
    """Sample the network ``runs`` times.

    :param network: network dict (see module docstring)
    :param start_nodes: passed to run_network on every run
    :param runs: number of independent samples to draw
    :return: numpy structured array with one row per run and one int field per
        node, named after the node and in ``network`` key order. A field holds 1
        where the node was True and -1 where it was False.
    """
    node_names = list(network.keys())
    rows = []
    for _ in range(runs):
        state = run_network(network, start_nodes)
        rows.append(tuple(1 if state[name] != 'F' else -1 for name in node_names))
    return np.array(rows, dtype=[(name, int) for name in node_names])
def compute_2x2(sim, na, nb, no=None):
    """Count the joint states of nodes ``na`` and ``nb`` over simulation runs.

    Row 0 holds runs where ``na`` was True (1) and row 1 runs where it was
    False (-1). Columns 0 and 1 do the same for ``nb``. When ``no`` is given,
    only runs where node ``no[0]`` had value ``no[1]`` (1 or -1) are counted.

    :param sim: structured array as returned by run_simulations
    :param na: name of the row node
    :param nb: name of the column node
    :param no: optional ``(node name, value)`` pair to condition on
    :return: 2x2 numpy array of counts
    """
    if no is None:
        keep = np.ones(len(sim[na]), dtype=bool)
    else:
        keep = sim[no[0]] == no[1]
    a_vals = sim[na][keep]
    b_vals = sim[nb][keep]
    counts = []
    for a_state in (1, -1):
        counts.append([((a_vals == a_state) & (b_vals == b_state)).sum()
                       for b_state in (1, -1)])
    return np.array(counts)
def chi_square_statistic(x):
    """Return Pearson's chi-square statistic for a contingency table.

    The expected count of a cell is (row total * column total) / grand total,
    i.e. the count predicted if rows and columns were independent. This covers
    the 2x2 tables produced by compute_2x2 and, more generally, any 2-D table of
    counts. No continuity correction is applied. A table with an empty row or
    column gives nan, with a numpy RuntimeWarning.

    :param x: 2-D array-like of counts
    :return: sum over cells of (observed - expected)**2 / expected
    """
    observed = np.asarray(x, dtype=float)
    # Outer product of the margins gives every cell's expected count at once.
    expected = np.outer(observed.sum(axis=1), observed.sum(axis=0)) / observed.sum()
    return ((observed - expected) ** 2 / expected).sum()
def phi(x):
    """Return the phi coefficient of a 2x2 contingency table.

    phi = (n00 * n11 - n01 * n10) / sqrt(r0 * r1 * c0 * c1), where r0, r1 are the
    row sums and c0, c1 the column sums. The result lies in [-1, 1]. It is nan
    (with a numpy RuntimeWarning) when a row or column is empty.

    :param x: 2x2 array-like of counts
    :return: the phi coefficient
    """
    # Work in floating point. With int64 counts the product of the four margins
    # overflows (silently wrapping) once a balanced table holds more than about
    # 110,000 samples.
    t = np.asarray(x, dtype=float)
    numerator = t[0, 0] * t[1, 1] - t[0, 1] * t[1, 0]
    denominator = np.sqrt(t[0, :].sum() * t[1, :].sum() * t[:, 0].sum() * t[:, 1].sum())
    return numerator / denominator
def probe_nodes(sim, na, nb, no=None):
    """Build the contingency table(s) of ``na`` vs ``nb`` and their chi-square statistics.

    Without an observed node, one table over all runs is built. With an
    observed node ``no``, two tables are built: first over the runs where ``no``
    was False (-1), then over the runs where it was True (1).

    :param sim: structured array as returned by run_simulations
    :param na: name of the first node
    :param nb: name of the second node
    :param no: optional name of the node to condition on
    :return: tuple ``(tables, statistics)`` of two parallel lists holding the
        2x2 count tables and the chi-square statistic of each
    """
    if no is None:
        conditions = [None]
    else:
        conditions = [(no, -1), (no, 1)]
    tables = [compute_2x2(sim, na, nb, cond) for cond in conditions]
    statistics = [chi_square_statistic(table) for table in tables]
    return tables, statistics
def pretty_print_probe_nodes(sim, na, nb, nobs=None, method='phi'):
    """Print an association statistic between nodes ``na`` and ``nb``.

    Without ``nobs``, one statistic over all runs is printed. With ``nobs``, a
    list of two statistics is printed: first for runs where ``nobs`` was False,
    then for runs where it was True.

    :param sim: structured array as returned by run_simulations
    :param na: name of the first node
    :param nb: name of the second node
    :param nobs: optional name of the node to condition on
    :param method: one of 'chi' (chi-square statistic), 'phi' (phi coefficient)
        or 'fisher' (p-value of Fisher's exact test)
    :raises ValueError: if ``method`` is not one of the names above
    """
    # Validate before doing any work. An unknown method would otherwise match
    # no branch below and leave ``stat`` unbound.
    if method not in ('chi', 'phi', 'fisher'):
        raise ValueError(
            "method must be one of 'chi', 'phi', 'fisher', got {!r}".format(method))
    conditions = [(nobs, -1), (nobs, 1)] if nobs is not None else [None]
    tables = [compute_2x2(sim, na, nb, cond) for cond in conditions]
    if method == 'chi':
        stat = [chi_square_statistic(t) for t in tables]
    elif method == 'phi':
        stat = [phi(t) for t in tables]
    else:
        # Index 1 of fisher_exact's result is the p-value.
        stat = [fisher_exact(t)[1] for t in tables]
    print('{} & {} {}: {}'.format(
        na, nb, 'with {} observed'.format(nobs) if nobs is not None else '',
        stat
    ))
class TestNetwork(unittest.TestCase):
    def test_network_validation(self):
        """validate_network rejects malformed networks and accepts the party network."""
        invalid_networks = [
            # Plain dict instead of OrderedDict.
            {'A': {'parents': ['B']}, 'B': {'parents': ['A']}},
            # A and B are each other's parent: a cycle.
            OrderedDict({'A': {'parents': ['B']}, 'B': {'parents': ['A']}}),
            # C is listed as a parent but is not a node of the network.
            OrderedDict({'A': {'parents': ['B']}, 'B': {'parents': ['C']}}),
            # Probability above 1 in A's table.
            OrderedDict({'A': {'parents': ['B'], 'table': {'X': 1.1}},
                         'B': {'parents': [], 'table': {'X': 0.2}}}),
        ]
        for bad_network in invalid_networks:
            with self.assertRaises(AssertionError):
                validate_network(bad_network)
        validate_network(party_network()[0])
def three_node_head_to_tail_unrelated():
    """Chain A -> C -> B in which B does not actually depend on C.

    B is True with probability 0.5 whatever the state of C.

    :return: ``(network, start_nodes)``. Evaluating B reaches C and A.
    """
    nodes = [
        ('A', {'parents': [], 'table': {'X': 0.5}}),
        ('C', {'parents': ['A'], 'table': {'T': 0.7, 'F': 0.3}}),
        ('B', {'parents': ['C'], 'table': {'T': 0.5, 'F': 0.5}}),
    ]
    return OrderedDict(nodes), ['B']
def three_node_head_to_tail_related():
    """Chain A -> C -> B in which C depends on A and B depends on C.

    Both C and B use the table T: 0.7, F: 0.3 over their single parent.

    :return: ``(network, start_nodes)``. Evaluating B reaches C and A.
    """
    nodes = [
        ('A', {'parents': [], 'table': {'X': 0.5}}),
        ('C', {'parents': ['A'], 'table': {'T': 0.7, 'F': 0.3}}),
        ('B', {'parents': ['C'], 'table': {'T': 0.7, 'F': 0.3}}),
    ]
    return OrderedDict(nodes), ['B']
def three_node_tail_to_tail():
    """Fork A <- C -> B: C is the single parent of both A and B.

    :return: ``(network, start_nodes)``. Evaluating A and B reaches C.
    """
    nodes = [
        ('A', {'parents': ['C'], 'table': {'T': 0.7, 'F': 0.3}}),
        ('C', {'parents': [], 'table': {'X': 0.5}}),
        ('B', {'parents': ['C'], 'table': {'T': 0.7, 'F': 0.3}}),
    ]
    return OrderedDict(nodes), ['A', 'B']
def three_node_head_to_head():
    """Collider A -> C <- B: C has the two parentless nodes A and B as parents.

    :return: ``(network, start_nodes)``. Evaluating C reaches A and B.
    """
    nodes = [
        ('A', {'parents': [], 'table': {'X': 0.5}}),
        ('C', {'parents': ['A', 'B'],
               'table': {'FF': 0.1, 'FT': 0.7, 'TF': 0.3, 'TT': 0.9}}),
        ('B', {'parents': [], 'table': {'X': 0.5}}),
    ]
    return OrderedDict(nodes), ['C']
def run_three_node_examples(runs=10000):
    """Simulate each three-node example and print the A/B statistic.

    For every example, the statistic is printed first unconditioned, then with C
    observed (pretty_print_probe_nodes' default method).

    :param runs: number of samples per example network
    """
    examples = (
        (three_node_head_to_tail_unrelated, 'A -> C -> B Unrelated'),
        (three_node_head_to_tail_related, 'A -> C -> B Related'),
        (three_node_tail_to_tail, 'A <- C -> B'),
        (three_node_head_to_head, 'A -> C <- B'),
    )
    for build, title in examples:
        net, start_nodes = build()
        sim = run_simulations(net, start_nodes, runs=runs)
        print(title)
        for observed in (None, 'C'):
            pretty_print_probe_nodes(sim, 'A', 'B', observed)
def party_network():
    """Set up the 7-node kid's party network.

    Edges (parent -> child): P -> K, P -> D, E -> F, K -> F, F -> G, D -> G,
    G -> C. C is True with probability 0.4 whatever the state of G.

    :return: ``(network, start_nodes)``. Evaluating C reaches every other node.
    """
    nodes = [
        ('E', {'parents': [], 'table': {'X': 0.1}}),
        ('K', {'parents': ['P'], 'table': {'T': 0.7, 'F': 0.1}}),
        ('P', {'parents': [], 'table': {'X': 0.1}}),
        ('F', {'parents': ['E', 'K'],
               'table': {'FF': 0.1, 'FT': 0.5, 'TF': 0.6, 'TT': 0.8}}),
        ('D', {'parents': ['P'], 'table': {'T': 0.8, 'F': 0.4}}),
        ('G', {'parents': ['F', 'D'],
               'table': {'FF': 0.1, 'FT': 0.4, 'TF': 0.6, 'TT': 0.8}}),
        ('C', {'parents': ['G'], 'table': {'T': 0.4, 'F': 0.4}}),
    ]
    return OrderedDict(nodes), ['C']
def run_party_example(runs=100000):
    """Simulate the party network and print statistics for selected node pairs.

    Each pair is printed first unconditioned, then with one node observed. F & D
    appears twice unconditioned so that it sits directly above each of its two
    conditioned variants (G observed, C observed).

    :param runs: number of samples to draw
    """
    print('Party network')
    net, start_nodes = party_network()
    sim = run_simulations(net, start_nodes, runs=runs)
    probes = [
        ('K', 'D', None), ('K', 'D', 'P'),
        ('F', 'D', None), ('F', 'D', 'G'),
        ('F', 'D', None), ('F', 'D', 'C'),
        ('P', 'G', None), ('P', 'G', 'D'),
        ('E', 'G', None), ('E', 'G', 'F'),
    ]
    for first, second, observed in probes:
        pretty_print_probe_nodes(sim, first, second, observed)
if __name__ == '__main__':
    # To run the unit tests instead of the simulations, call unittest.main() here.
    simulation_runs = 100000
    run_three_node_examples(runs=simulation_runs)
    run_party_example(runs=simulation_runs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment