# bdatko/factorizing.py

Created July 1, 2023 19:07
All Factorizing Joint Distributions
    from itertools import permutations


    def chain_rule(permutation):
        """Compute factorization for a given permutation of variables."""
        n = len(permutation)
        if n == 1:
            return [f'P({permutation[0]})']
        else:
            conditional = ','.join(permutation[1:])
            return [f'P({permutation[0]}|{conditional})'] + chain_rule(permutation[1:])


    def compute_factorizations(variables):
        """Compute all factorizations of the joint distribution of the given variables."""
        factorizations = []
        for perm in permutations(variables):
            factorizations.append(chain_rule(perm))
        return factorizations


    # Test with n=3
    variables = ['A', 'B', 'C']
    for factoring in compute_factorizations(variables):
        factoring.reverse()
        print("".join(factoring))

    # Output
    # P(C)P(B|C)P(A|B,C)
    # P(B)P(C|B)P(A|C,B)
    # P(C)P(A|C)P(B|A,C)
    # P(A)P(C|A)P(B|C,A)
    # P(B)P(A|B)P(C|A,B)
    # P(A)P(B|A)P(C|B,A)

### bdatko commented Jul 1, 2023

Obtained from ChatGPT-4 with the following prompts:

## Input Prompt

Can you write me a python program to list out all the ways to factorize a joint distribution of n variables?

1. Example n = 2
Input: A, B
Output: [P(A)P(B|A), P(B)P(A|B)]
2. Example n = 3
Input: A, B, C
Output: [ P(A)P(B|A)P(C|A,B), P(B)P(C|B)P(A|C,B), ... ]

## ChatGPT Output

Sure, to tackle this problem, we can iterate over the permutations of the list of variables and, for each permutation, use a recursive function to build the factorization expression according to the chain rule of probability.
This Python script provides the required functionality:

    from itertools import permutations


    def chain_rule(permutation):
        """Compute factorization for a given permutation of variables."""
        n = len(permutation)
        if n == 1:
            return f'P({permutation[0]})'
        else:
            conditional = ','.join(permutation[1:])
            return f'P({permutation[0]}|{conditional})' + chain_rule(permutation[1:])


    def compute_factorizations(variables):
        """Compute all factorizations of the joint distribution of the given variables."""
        factorizations = []
        for perm in permutations(variables):
            factorizations.append(chain_rule(perm))
        return factorizations


    # Test with n=2
    variables = ['A', 'B']
    print(compute_factorizations(variables))

    # Test with n=3
    variables = ['A', 'B', 'C']
    print(compute_factorizations(variables))

This script will print all the factorization forms for a joint distribution of 2 and 3 variables. Note that this solution does not consider independencies among variables and assumes that every variable might be dependent on all the others.
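For reference, the identity behind each printed expression is the chain rule of probability. Written for an arbitrary ordering $\pi$ of the $n$ variables, with factors in the order `chain_rule` emits them (the symbols $X_i$ and $\pi$ are notation assumed here for illustration, not part of the original exchange):

```latex
P(X_1, \dots, X_n) = \prod_{i=1}^{n} P\bigl(X_{\pi(i)} \mid X_{\pi(i+1)}, \dots, X_{\pi(n)}\bigr)
```

The last factor ($i = n$) has an empty conditioning set and reduces to $P(X_{\pi(n)})$. The identity holds for every ordering $\pi$, which is why each permutation yields a valid factorization.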
Also, this solution is practical only for a small number of variables. There are n! permutations of n variables, so the number of factorizations grows factorially: 10 variables already give 3,628,800 factorizations, and for larger n it becomes slow or infeasible to compute them all.
Let me know if this is what you were asking for, or if you need further help!
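The factorial growth noted above mostly hurts when every factorization is stored in a list at once. As a rough sketch, assuming a hypothetical generator named `iter_factorizations` (not part of the original script), the factorizations can be produced lazily so that only the ones actually consumed are built. The total count is still n!, so this saves memory, not time, when all of them are needed.

```python
from itertools import islice, permutations
from math import factorial


def iter_factorizations(variables):
    """Lazily yield one chain-rule factorization (as a string) per ordering.

    Hypothetical helper for illustration; not part of the original script.
    """
    for perm in permutations(variables):
        factors = []
        for i, var in enumerate(perm):
            # Condition each variable on the variables that follow it in perm.
            rest = perm[i + 1:]
            factors.append(f"P({var}|{','.join(rest)})" if rest else f"P({var})")
        yield "".join(factors)


variables = ["A", "B", "C", "D"]
total = factorial(len(variables))  # number of orderings, here 4! = 24
first_two = list(islice(iter_factorizations(variables), 2))
print(total)
print(first_two)
```

Here `islice` takes only the first two factorizations, so the remaining 22 orderings are never materialized.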

## Input Prompt

Can you modify such that the return is a list of the components? For example,
Input: 'P(A|B,C)P(B|C)P(C)'
Output: ['P(A|B,C)', 'P(B|C)', 'P(C)']

## ChatGPT Output

Sure, to modify the code to output lists of components instead of the entire expression as a string, we can adjust the chain_rule function to output a list of strings at each recursive call. Here is the adjusted code:

    from itertools import permutations


    def chain_rule(permutation):
        """Compute factorization for a given permutation of variables."""
        n = len(permutation)
        if n == 1:
            return [f'P({permutation[0]})']
        else:
            conditional = ','.join(permutation[1:])
            return [f'P({permutation[0]}|{conditional})'] + chain_rule(permutation[1:])


    def compute_factorizations(variables):
        """Compute all factorizations of the joint distribution of the given variables."""
        factorizations = []
        for perm in permutations(variables):
            factorizations.append(chain_rule(perm))
        return factorizations


    # Test with n=2
    variables = ['A', 'B']
    print(compute_factorizations(variables))

    # Test with n=3
    variables = ['A', 'B', 'C']
    print(compute_factorizations(variables))

This version of the script will now print the factorization forms as lists of their components. For example, for the ['A', 'B', 'C'] input, it will output something like:

[['P(A|B,C)', 'P(B|C)', 'P(C)'], ['P(A|C,B)', 'P(C|B)', 'P(B)'], ['P(B|A,C)', 'P(A|C)', 'P(C)'], ...]
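The component lists come out with the fully conditioned factor first, whereas the first prompt wrote the marginal first, as in `P(A)P(B|A)`. The gist file at the top handles this by reversing each list before joining. A minimal sketch of the same idea, using a hypothetical helper `as_product` (a name assumed here, not from the original) that leaves the input list unmodified:

```python
from itertools import permutations


def chain_rule(permutation):
    """Return the chain-rule factors for one ordering, as a list of strings."""
    if len(permutation) == 1:
        return [f"P({permutation[0]})"]
    conditional = ",".join(permutation[1:])
    return [f"P({permutation[0]}|{conditional})"] + chain_rule(permutation[1:])


def as_product(factors):
    """Join factors marginal-first without mutating the input list.

    Hypothetical helper for illustration; not part of the original script.
    """
    return "".join(reversed(factors))


products = [as_product(chain_rule(p)) for p in permutations(["A", "B"])]
print(products)
```

For two variables this gives the same two expressions as the n = 2 example in the first prompt, though not necessarily in the same order.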