Skip to content

Instantly share code, notes, and snippets.

@kghose
Last active April 13, 2017 20:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kghose/fb997a32150bbea6573256a2ebb6a2a4 to your computer and use it in GitHub Desktop.
Save kghose/fb997a32150bbea6573256a2ebb6a2a4 to your computer and use it in GitHub Desktop.
Simple simulator to explore conditional independence in Bayesian networks
"""Simple simulator to explore conditional independence in Bayesian networks
Node representation:
Suppose a node has two parents, named L and K, and its probability table is
L | K | p(T)
--------------
F | F | 0.2
F | T | 0.4
T | F | 0.7
T | T | 0.9
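The node is represented as the following dictionary; each table key concatenates the parents' states ('T'/'F') in the order the parents are listed in 'parents':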
{
'parents': ['L', 'K'],
'table': {
'FF': 0.2,
'FT': 0.4,
'TF': 0.7,
'TT': 0.9
}
}
A node with no parents has a single table entry under the key 'X':
{
'parents': [],
'table': {
'X': 0.2,
}
}
A network is represented as an OrderedDict (validate_network rejects plain dicts) whose keys are the
node names and whose values are the node dictionaries described above.
So a simple two-node network A -> B with probability tables
p(A=T) = 0.4
A | p(B=T)
------------
T | 0.2
F | 0.6
is represented as
{
'A': {
'parents': [],
'table': {
'X': 0.4
}
},
'B': {
'parents': ['A'],
'table': {
'T': 0.2,
'F': 0.6
}
}
}
"""
import unittest
from collections import OrderedDict
import numpy as np
from scipy.stats import fisher_exact
def validate_network(net):
    """Run simple sanity checks on a network, raising if any of them fails.

    Checks, in order:

    1. ``net`` is an ``OrderedDict``.
    2. Every parent listed by a node is itself a node of ``net``.
    3. The parent links contain no cycle (delegated to ``cycle_check``).
    4. Every probability in every node's ``'table'`` lies in [0, 1].

    The failures are raised explicitly instead of with ``assert`` statements
    so that validation still happens when Python runs with ``-O`` (which
    strips asserts). ``AssertionError`` is kept as the exception type because
    existing callers and tests expect it.

    :param net: network dictionary (see the module docstring for the format)
    :raises AssertionError: if any of the checks above fails
    """
    if not isinstance(net, OrderedDict):
        raise AssertionError('Network must be OrderedDict')
    # Parent existence is checked first so cycle_check never looks up a
    # missing node.
    for name, node in net.items():
        for parent in node['parents']:
            if parent not in net:
                raise AssertionError(
                    'Node {} has parent {} not in network'.format(name, parent))
    for name in net:
        cycle_check(name, name, net)
    for name, node in net.items():
        for p in node['table'].values():
            # Written as "not (in range)" so NaN is rejected as well.
            if not 0 <= p <= 1.0:
                raise AssertionError(
                    'Node {} has out of range probability in table'.format(name))
def cycle_check(descendant, ancestor, network):
    """Raise if ``descendant`` is reachable by walking up from ``ancestor``.

    Starting at ``ancestor``, follow ``'parents'`` links through ``network``.
    If ``descendant`` ever appears as a parent, the network has a cycle
    through ``descendant``. ``validate_network`` calls this as
    ``cycle_check(k, k, net)`` for every node ``k``.

    The walk is iterative and remembers every node it has already queued.
    This means it terminates even when the graph contains a cycle that does
    *not* pass through ``descendant``; a naive recursive walk would loop
    forever there until ``RecursionError``. It also means long parent chains
    cannot hit the recursion limit.

    :param descendant: node searched for among the ancestors of ``ancestor``
    :param ancestor: node from which to start walking up the parent links
    :param network: network dictionary (see the module docstring)
    :raises AssertionError: if ``descendant`` is an ancestor of ``ancestor``
    :raises KeyError: if a visited parent name is not a node of ``network``
    """
    visited = set()
    pending = [ancestor]
    while pending:
        node = pending.pop()
        for p in network[node]['parents']:
            if p == descendant:
                raise AssertionError('Cycle in given network')
            if p not in visited:
                visited.add(p)
                pending.append(p)
def get_p_true(node, inputs=None):
    """Look up the probability that ``node`` takes the True state.

    :param node: node dictionary with ``'parents'`` and ``'table'`` entries
    :param inputs: mapping of parent name -> ``'T'``/``'F'``; for parents
        ``['L', 'K']`` it looks like ``{'L': 'T', 'K': 'F'}``. Not consulted
        when the node has no parents.
    :return: the table entry whose key is the parents' states concatenated in
        ``node['parents']`` order, or the ``'X'`` entry for a parentless node
    """
    parents = node['parents']
    if not parents:
        return node['table']['X']
    key = ''.join([inputs[name] for name in parents])
    return node['table'][key]
def get_node_value(node_name, network, network_state):
    """Return the state of one node, sampling it (and its ancestors) if needed.

    ``network_state`` doubles as a memo. If ``node_name`` already has a state
    there, that state is returned as is. Otherwise:

    1. Every parent is resolved recursively, in ``'parents'`` order.
    2. The node's probability of being True is looked up with ``get_p_true``.
    3. One ``np.random.rand()`` draw decides the state, which is stored in
       ``network_state``.

    :param node_name: name of the node to evaluate
    :param network: network dictionary (see the module docstring)
    :param network_state: dict of node name -> ``'T'``/``'F'``; updated in place
    :return: ``'T'`` or ``'F'``
    """
    if node_name in network_state:
        return network_state[node_name]
    node = network[node_name]
    parent_states = {}
    for parent in node['parents']:
        parent_states[parent] = get_node_value(parent, network, network_state)
    p_true = get_p_true(node, inputs=parent_states)
    network_state[node_name] = 'T' if np.random.rand() < p_true else 'F'
    return network_state[node_name]
def run_network(network, start_nodes):
    """Draw a single joint sample of the network.

    Each start node is evaluated with ``get_node_value``, which also samples
    all of that node's ancestors. Nodes that are neither a start node nor an
    ancestor of one are left out of the result.

    :param network: network dictionary (see the module docstring)
    :param start_nodes: names of the nodes whose evaluation should cover the
        whole network
    :return: dict mapping each sampled node name to ``'T'`` or ``'F'``
    """
    sample = {}
    for name in start_nodes:
        get_node_value(name, network, sample)
    return sample
def run_simulations(network, start_nodes, runs=1000):
    """Sample the network ``runs`` times and collect the results in an array.

    :param network: network dictionary; its key order fixes the field order
        of the returned array
    :param start_nodes: passed unchanged to ``run_network`` for every sample
    :param runs: number of samples to draw
    :return: structured numpy array of length ``runs`` with one ``int`` field
        per node, named after the node, holding -1 where the node was
        ``'F'`` and 1 otherwise
    """
    names = list(network.keys())

    def encode(sample):
        # One row per sample, columns in network key order.
        return tuple(-1 if sample[name] == 'F' else 1 for name in names)

    rows = []
    for _ in range(runs):
        rows.append(encode(run_network(network, start_nodes)))
    return np.array(rows, dtype=[(name, int) for name in names])
def compute_2x2(sim, na, nb, no=None):
    """Build the 2x2 contingency table of nodes ``na`` and ``nb``.

    Rows are indexed by ``na`` and columns by ``nb``. Index 0 counts samples
    where the node is 1 (True); index 1 counts samples where it is -1 (False).

    :param sim: structured array from ``run_simulations``
    :param na: node name for the rows
    :param nb: node name for the columns
    :param no: optional ``(node_name, value)`` pair; when given, only samples
        with ``sim[node_name] == value`` are counted
    :return: 2x2 numpy array of counts
    """
    if no is None:
        mask = np.ones(len(sim[na]), dtype=bool)
    else:
        mask = sim[no[0]] == no[1]
    a_values = sim[na][mask]
    b_values = sim[nb][mask]
    table = []
    for a in (1, -1):
        table.append([((a_values == a) & (b_values == b)).sum() for b in (1, -1)])
    return np.array(table)
def chi_square_statistic(x):
    """Pearson chi-square statistic of a contingency table, without continuity correction.

    Works for any two-dimensional table of counts; 2x2 tables are the common
    case in this module. The expected count of cell (i, j) is
    ``row_total[i] * col_total[j] / grand_total``.

    :param x: 2-D array-like of counts
    :return: sum over all cells of ``(observed - expected)**2 / expected``.
        If a whole row or column is empty, its expected counts are zero and
        the result is nan (numpy float division by zero).
    """
    x = np.asarray(x)
    # Column share of the grand total, times each row total -> expected counts.
    col_share = x.sum(axis=0) / float(x.sum())
    expected_x = np.outer(x.sum(axis=1), col_share)
    return ((x - expected_x) ** 2 / expected_x).sum()
def phi(x):
    """Phi coefficient of a 2x2 contingency table.

    ``phi = (a*d - b*c) / sqrt(r0 * r1 * c0 * c1)``, where ``[[a, b], [c, d]]``
    is the table and ``r*`` / ``c*`` are its row and column totals.

    The table is converted to float before multiplying. With integer counts,
    the product of the four margins can overflow int64 once the table holds
    more than roughly 110,000 samples. That overflow silently yields a wrong
    (often nan) result.

    :param x: 2x2 array-like of counts
    :return: the phi coefficient; nan if a whole row or column is empty
    """
    t = np.asarray(x, dtype=float)
    numerator = t[0, 0] * t[1, 1] - t[0, 1] * t[1, 0]
    margin_product = t.sum(axis=1).prod() * t.sum(axis=0).prod()
    return numerator / margin_product ** 0.5
def probe_nodes(sim, na, nb, no=None):
    """Contingency tables and chi-square statistics for the node pair (na, nb).

    Without an observed node, a single table over all samples is built. With
    an observed node ``no``, two tables are built: first over the samples
    where ``no`` is False (-1), then over those where it is True (1).

    :param sim: structured array from ``run_simulations``
    :param na: first node name
    :param nb: second node name
    :param no: optional name of the observed (conditioning) node
    :return: ``(tables, statistics)``, where ``tables`` is a list of 2x2 count
        arrays and ``statistics`` holds their chi-square statistics in the
        same order
    """
    if no is None:
        conditions = [None]
    else:
        conditions = [(no, -1), (no, 1)]
    tables = [compute_2x2(sim, na, nb, cond) for cond in conditions]
    statistics = [chi_square_statistic(table) for table in tables]
    return tables, statistics
def pretty_print_probe_nodes(sim, na, nb, nobs=None, method='phi'):
    """Print an association statistic between nodes ``na`` and ``nb``.

    When ``nobs`` is given, the statistic is computed separately for the
    samples where ``nobs`` is False (-1) and where it is True (1). Both values
    are printed, in that order.

    :param sim: structured array from ``run_simulations``
    :param na: first node name
    :param nb: second node name
    :param nobs: optional name of the observed (conditioning) node
    :param method: One of 'chi' (chi-square statistic), 'phi' (phi
        coefficient), 'fisher' (p-value from ``scipy.stats.fisher_exact``)
    :raises ValueError: if ``method`` is not one of the above. It is checked
        up front so a typo fails clearly, instead of leaving the statistic
        unbound (UnboundLocalError).
    """
    if method not in ('chi', 'phi', 'fisher'):
        raise ValueError(
            "method must be one of 'chi', 'phi', 'fisher'; got {!r}".format(method))
    conditions = [(nobs, -1), (nobs, 1)] if nobs is not None else [None]
    tables = [compute_2x2(sim, na, nb, cond) for cond in conditions]
    if method == 'chi':
        stat = [chi_square_statistic(t) for t in tables]
    elif method == 'phi':
        stat = [phi(t) for t in tables]
    else:
        # fisher_exact returns (statistic, p-value); keep the p-value.
        stat = [fisher_exact(t)[1] for t in tables]
    print('{} & {} {}: {}'.format(
        na, nb, 'with {} observed'.format(nobs) if nobs is not None else '',
        stat
    ))
class TestNetwork(unittest.TestCase):
def test_network_validation(self):
"""Network validation"""
with self.assertRaises(AssertionError):
validate_network({'A': {'parents': ['B']}, 'B': {'parents': ['A']}})
with self.assertRaises(AssertionError):
validate_network(OrderedDict({'A': {'parents': ['B']}, 'B': {'parents': ['A']}}))
with self.assertRaises(AssertionError):
validate_network(OrderedDict({'A': {'parents': ['B']}, 'B': {'parents': ['C']}}))
with self.assertRaises(AssertionError):
validate_network(OrderedDict({'A': {'parents': ['B'], 'table': {'X': 1.1}},
'B': {'parents': [], 'table': {'X': 0.2}}}))
validate_network(party_network()[0])
def three_node_head_to_tail_unrelated():
    """Chain A -> C -> B in which B's table ignores C.

    B is True with probability 0.5 whatever C's state, so A and B are
    independent.

    :return: ``(network, start_nodes)`` with start node B
    """
    network = OrderedDict()
    network['A'] = {'parents': [], 'table': {'X': 0.5}}
    network['C'] = {'parents': ['A'], 'table': {'T': 0.7, 'F': 0.3}}
    network['B'] = {'parents': ['C'], 'table': {'T': 0.5, 'F': 0.5}}
    return network, ['B']
def three_node_head_to_tail_related():
    """Chain A -> C -> B in which both links carry information.

    :return: ``(network, start_nodes)`` with start node B
    """
    network = OrderedDict()
    network['A'] = {'parents': [], 'table': {'X': 0.5}}
    network['C'] = {'parents': ['A'], 'table': {'T': 0.7, 'F': 0.3}}
    network['B'] = {'parents': ['C'], 'table': {'T': 0.7, 'F': 0.3}}
    return network, ['B']
def three_node_tail_to_tail():
    """Common cause A <- C -> B.

    :return: ``(network, start_nodes)`` with start nodes A and B (together
        they pull in C)
    """
    network = OrderedDict()
    network['A'] = {'parents': ['C'], 'table': {'T': 0.7, 'F': 0.3}}
    network['C'] = {'parents': [], 'table': {'X': 0.5}}
    network['B'] = {'parents': ['C'], 'table': {'T': 0.7, 'F': 0.3}}
    return network, ['A', 'B']
def three_node_head_to_head():
    """Common effect (collider) A -> C <- B.

    :return: ``(network, start_nodes)`` with start node C (which pulls in
        both A and B)
    """
    network = OrderedDict()
    network['A'] = {'parents': [], 'table': {'X': 0.5}}
    network['C'] = {
        'parents': ['A', 'B'],
        # keys are <A state><B state>
        'table': {'FF': 0.1, 'FT': 0.7, 'TF': 0.3, 'TT': 0.9},
    }
    network['B'] = {'parents': [], 'table': {'X': 0.5}}
    return network, ['C']
def run_three_node_examples(runs=10000):
    """Simulate each three-node network and print A/B association statistics.

    For every network, this prints its title, then the default statistic of
    ``pretty_print_probe_nodes`` for A & B. The statistic is printed first
    over all samples and then split by the state of C.

    :param runs: number of samples drawn per network
    """
    examples = (
        (three_node_head_to_tail_unrelated, 'A -> C -> B Unrelated'),
        (three_node_head_to_tail_related, 'A -> C -> B Related'),
        (three_node_tail_to_tail, 'A <- C -> B'),
        (three_node_head_to_head, 'A -> C <- B'),
    )
    for build, title in examples:
        network, start_nodes = build()
        samples = run_simulations(network, start_nodes, runs=runs)
        print(title)
        pretty_print_probe_nodes(samples, 'A', 'B')
        pretty_print_probe_nodes(samples, 'A', 'B', 'C')
def party_network():
    """Build the 7-node kid's party network.

    Parent -> child links:
        P -> K, P -> D, E -> F, K -> F, F -> G, D -> G, G -> C

    :return: ``(network, start_nodes)`` with start node C; evaluating C
        reaches every other node through its ancestors
    """
    network = OrderedDict()
    network['E'] = {'parents': [], 'table': {'X': 0.1}}
    network['K'] = {'parents': ['P'], 'table': {'T': 0.7, 'F': 0.1}}
    network['P'] = {'parents': [], 'table': {'X': 0.1}}
    network['F'] = {
        'parents': ['E', 'K'],
        'table': {'FF': 0.1, 'FT': 0.5, 'TF': 0.6, 'TT': 0.8},
    }
    network['D'] = {'parents': ['P'], 'table': {'T': 0.8, 'F': 0.4}}
    network['G'] = {
        'parents': ['F', 'D'],
        'table': {'FF': 0.1, 'FT': 0.4, 'TF': 0.6, 'TT': 0.8},
    }
    # NOTE(review): C is True with probability 0.4 whichever state G is in,
    # so C is independent of G and observing C says nothing about G. If C is
    # meant to be an informative descendant of G, one of these entries
    # presumably should differ -- confirm.
    network['C'] = {'parents': ['G'], 'table': {'T': 0.4, 'F': 0.4}}
    return network, ['C']  # network and the start_node list
def run_party_example(runs=100000):
    """Simulate the party network and print a fixed series of pairwise probes.

    Each probe prints the default statistic of ``pretty_print_probe_nodes`` for
    a pair of nodes. Some probes are split by the state of an observed node.

    :param runs: number of samples to draw
    """
    print('Party network')
    network, start_nodes = party_network()
    samples = run_simulations(network, start_nodes, runs=runs)
    # (first node, second node, observed node or None). The unconditioned
    # (F, D) probe appears twice, so it is printed just before both its
    # G-conditioned and its C-conditioned counterpart.
    probes = [
        ('K', 'D', None), ('K', 'D', 'P'),
        ('F', 'D', None), ('F', 'D', 'G'),
        ('F', 'D', None), ('F', 'D', 'C'),
        ('P', 'G', None), ('P', 'G', 'D'),
        ('E', 'G', None), ('E', 'G', 'F'),
    ]
    for first, second, observed in probes:
        pretty_print_probe_nodes(samples, first, second, observed)
if __name__ == '__main__':
    # Runs the demos; call unittest.main() here instead to run TestNetwork.
    for demo in (run_three_node_examples, run_party_example):
        demo(runs=100000)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment