Skip to content

Instantly share code, notes, and snippets.

@rodonn
Last active December 9, 2023 21:28
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rodonn/63bdcb6787f1643b80581f7ff69a67f4 to your computer and use it in GitHub Desktop.
Save rodonn/63bdcb6787f1643b80581f7ff69a67f4 to your computer and use it in GitHub Desktop.
Causal DAG simulator
import graphviz as gr
import pandas as pd
def simulate(**kwargs):
    """Simulate a dataset from structural equations and draw their causal DAG.

    Each keyword argument names one variable; its value is a function whose
    parameter names are the parent variables it depends on. Variables are
    evaluated in the order the keywords are given, so every parent must be
    defined by an earlier keyword.

    Args:
        **kwargs: variable name -> function of that variable's parents.

    Returns:
        ``(data, g)``: a ``pandas.DataFrame`` with one column per variable,
        and a ``graphviz.Digraph`` with an edge parent -> child for every
        dependency.

    Raises:
        KeyError: if a function names a parent that no earlier keyword
            defined.
    """
    values = {}
    g = gr.Digraph()
    for name, func in kwargs.items():
        code = func.__code__
        # co_varnames lists the parameters first (positional, then
        # keyword-only) and the function's local variables after them; only
        # the parameters are parents, so slice the locals off.
        parents = code.co_varnames[:code.co_argcount + code.co_kwonlyargcount]
        missing = [p for p in parents if p not in values]
        if missing:
            raise KeyError(
                f"variable {name!r} depends on {missing}, which must be "
                "defined by an earlier keyword argument"
            )
        values[name] = func(**{p: values[p] for p in parents})
        # Add the node explicitly so variables without parents still appear.
        g.node(name)
        for p in parents:
            g.edge(p, name)
    data = pd.DataFrame(values)
    return data, g
# Alternative version that gives special treatment for the number of rows N
import graphviz as gr
import pandas as pd
def get_function_args(func):
    """Return the names of the parameters *func* accepts by keyword.

    A code object's ``co_varnames`` starts with the parameters: positional
    ones (``co_argcount``), then keyword-only ones (``co_kwonlyargcount``).
    Next come the ``*args`` / ``**kwargs`` names, then ordinary local
    variables. Slicing off just the parameter prefix keeps locals (and the
    variadic names) out of the result.

    Args:
        func: a Python function or lambda (anything with a ``__code__``).

    Returns:
        A tuple of parameter names: positional parameters first, then
        keyword-only parameters.
    """
    code = func.__code__
    return code.co_varnames[:code.co_argcount + code.co_kwonlyargcount]
def simulate(N: int, **kwargs):
    """Simulate N rows from structural equations and draw their causal DAG.

    Each keyword argument names one variable; its value is a function whose
    parameter names are the variables it depends on. A parameter named ``N``
    receives the row count rather than a simulated variable. Variables are
    evaluated in the order the keywords are given. Parameters that match no
    earlier variable are not passed, so they fall back to their defaults.

    Args:
        N: number of rows; passed to every function with a parameter named
            ``N``.
        **kwargs: variable name -> function of that variable's parents.

    Returns:
        ``(data, g)``: a ``pandas.DataFrame`` with one column per variable,
        and a ``graphviz.Digraph`` with a node per variable and an edge
        parent -> child for every dependency between simulated variables.
    """
    values = {}
    g = gr.Digraph()
    for variable_name, function in kwargs.items():
        parents = get_function_args(function)
        # Causal parents are the parameters bound to already-simulated
        # variables. `N` can never be a key of `values` (it is simulate's own
        # parameter), so the row count and defaulted parameters are excluded.
        # Computed before the call so a variable can't become its own parent.
        causes = [p for p in parents if p in values]
        inputs = {p: values[p] for p in causes}
        if 'N' in parents:
            inputs['N'] = N
        values[variable_name] = function(**inputs)
        # Always add the node so variables with no simulated parents are drawn.
        g.node(variable_name)
        for p in causes:
            g.edge(p, variable_name)
    data = pd.DataFrame(values)
    return data, g
# Example usage:
import numpy as np
from numpy.random import normal, uniform, choice
def get_income(age, height, gender, N):
    """Draw N incomes given age, height and gender.

    The mean is ``100*age + 10*height``. The standard deviation is 1000,
    plus another 1000 wherever ``gender == 'male'``.
    """
    mean_income = 100 * age + 10 * height
    income_spread = np.where(gender == 'male', 1000, 0) + 1000
    return normal(mean_income, income_spread, N)
# Each keyword below is one node of the causal DAG. Keyword order matters:
# a variable may only depend on variables listed before it (plus N).
df, g = simulate(
    N=100,
    # Root variables: they depend only on the number of rows.
    age=lambda N: uniform(0, 100, N),
    gender=lambda N: choice(['male', 'female'], N),
    # Height depends on age: one extra unit for anyone older than 15.
    height=lambda age, N: normal(4.5, 1, N) + np.where(age > 15, 1, 0),
    # Income depends on age, height and gender.
    income=get_income,
)
df  # bare expression: displays the table when run as a notebook cell
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment