# Source: GitHub gist maropu/fb6c0501469794d5e5f111d8e40f6d0b by @maropu,
# created February 7, 2020 00:01.
# https://qiita.com/9_ties/items/3bdb177384937ddc88df
# https://homes.cs.washington.edu/~pedrod/papers/mlj05.pdf
import pandas as pd
import numpy as np
from scipy.special import logsumexp
from itertools import product
# Constants of the domain and the predicates defined over them.
const = ['A', 'B']
preds = [('Smokes', 1), ('Cancer', 1), ('Friends', 2)]  # (predicate name, arity)

# Ground atoms: each predicate applied to every tuple of constants whose
# length equals the predicate's arity, e.g. ('Friends', 'A', 'B').
ground_atoms = [
    (name,) + binding
    for name, n_args in preds
    for binding in product(const, repeat=n_args)
]
print('=== Ground Atoms ===')
print(ground_atoms)
# Weighted clauses of the model. Each entry is
#     (literals, negation flags, arity, weight)
# * literals: (predicate, variable indices) pairs; every index selects one of
#   the clause's `arity` variables.
# * negation flags: one flag per literal. 1 marks the literal as negated,
#   because the scoring step XORs each atom's truth value with its flag.
# * arity: number of distinct variables in the clause.
# * weight: the clause weight w_i.
formulas = [
    # !Smokes(x) v Cancer(x)
    ([('Smokes', (0,)), ('Cancer', (0,))], [1, 0], 1, 1.5),
    # !Friends(x, y) v Smokes(x) v !Smokes(y)
    ([('Friends', (0, 1)), ('Smokes', (0,)), ('Smokes', (1,))], [1, 0, 1], 2, 1.1),
    # !Friends(x, y) v !Smokes(x) v Smokes(y)
    ([('Friends', (0, 1)), ('Smokes', (0,)), ('Smokes', (1,))], [1, 1, 0], 2, 1.1),
]

# Ground every clause: substitute each possible tuple of constants for the
# clause's variables. A ground formula is stored as
#     (list of ground atoms, negation flags, weight)
# where the atoms have the same tuple form as the entries of `ground_atoms`.
ground_formulas = []
for clauses, neg, arity, w in formulas:
    for args in product(const, repeat=arity):
        ground_formula = [
            # Replace the variable indices in `v` by the bound constants.
            (p, *(args[i] for i in v))
            for p, v in clauses
        ]
        ground_formulas.append((ground_formula, neg, w))
print('=== Ground Formulas ===')
print(ground_formulas)
# Generate all configurations: one row per truth assignment to the ground
# atoms (2 ** len(ground_atoms) rows), one column per ground atom.
X = pd.DataFrame(
    columns=ground_atoms,
    data=list(product([1, 0], repeat=len(ground_atoms))),
)

# Compute sum_i(w_i * n_i(x)) for every configuration x: each ground formula
# adds its weight to the rows in which it is satisfied.
S = np.zeros(len(X))
for f, neg, w in ground_formulas:
    # A literal is true when the atom's value differs from its negation flag;
    # the clause is a disjunction, so it holds if any literal is true.
    # `axis` must be given by keyword: pandas 2.0+ rejects the positional
    # form `.any(1)` with a TypeError.
    S += w * np.logical_xor(X[f], neg).any(axis=1)

# Compute the log partition function: log of the sum of exp(S) over all rows.
logZ = logsumexp(S)

# Compute joint log-probabilities: log P(x) = S(x) - log Z.
jointP = X.copy()
jointP['logP'] = S - logZ
print('=== Joint Probability ===')
print(jointP)
# Example queries answered from the joint distribution table.
def _marginal(atoms):
    """Return the marginal probabilities over the given ground atoms.

    Rows of ``jointP`` are grouped by the columns listed in ``atoms``; each
    group's ``logP`` values are combined with ``logsumexp`` and the result is
    exponentiated.
    """
    log_marginal = jointP.groupby(atoms)['logP'].agg(logsumexp)
    return np.exp(log_marginal)


print('=== P(Friends(A, B)) ===')
P_FrAB = _marginal([('Friends', 'A', 'B')])
print(P_FrAB)

print('=== P(Friends(A, B)|Smokes(A)) ===')
# Conditional probability = joint marginal divided by the marginal of the
# conditioning atom.
P_FrAB_SmoA = _marginal([('Smokes', 'A'), ('Friends', 'A', 'B')])
P_SmoA = _marginal([('Smokes', 'A')])
print(P_FrAB_SmoA / P_SmoA)
# (GitHub page footer: sign-up / sign-in prompt for commenting on the gist.)