Gist NaimKabir/4f1cf7759054ddb270cebf45d0ad210c, last active October 21, 2019 03:11
Showing how to find a max ATE by optimizing an LP problem
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Linear optimization\n", | |
"\n", | |
"First we'll list out all the probabilities we need: p(x,y|z) for each x,y,z combination." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>P( X, Y | Z )</th>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>state</th>\n", | |
" <th></th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>control/untreated/bad</th>\n", | |
" <td>0.315285</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>control/untreated/good</th>\n", | |
" <td>0.043756</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>control/treated/bad</th>\n", | |
" <td>0.323676</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>control/treated/good</th>\n", | |
" <td>0.317283</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>treatment/untreated/bad</th>\n", | |
" <td>0.019820</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>treatment/untreated/good</th>\n", | |
" <td>0.670671</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>treatment/treated/bad</th>\n", | |
" <td>0.167968</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>treatment/treated/good</th>\n", | |
" <td>0.141542</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" P( X, Y | Z )\n", | |
"state \n", | |
"control/untreated/bad 0.315285\n", | |
"control/untreated/good 0.043756\n", | |
"control/treated/bad 0.323676\n", | |
"control/treated/good 0.317283\n", | |
"treatment/untreated/bad 0.019820\n", | |
"treatment/untreated/good 0.670671\n", | |
"treatment/treated/bad 0.167968\n", | |
"treatment/treated/good 0.141542" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"p_states" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Now we can frame our linear programming problem.\n", | |
"\n", | |
"We're maximizing our expression of an ATE given the constraints at hand." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import pulp\n", | |
"\n", | |
"# first we set up 'problems', objects that\n", | |
"# can take an objective like 'maximize' as well as constraints\n", | |
"\n", | |
"max_problem = pulp.LpProblem(\n", | |
" \"max ATE\", pulp.LpMaximize\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We can now specify variables, symbols that the linear program will know to vary in order to maximize or minimize our objective function. These variables are the distribution of **U**: the probabilities of belonging to one of the archetypal (Compliance, Response) partitions." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Linear optimization ranges over possible variables\n", | |
"# to find those hidden variables that maximize\n", | |
"# or minimize our objective.\n", | |
"\n", | |
"# Our variables 'q' are the the probability of\n", | |
"# being in one of the partitions\n", | |
"\n", | |
"partition_names = [\n", | |
" \"/\".join([compliance, response])\n", | |
" for compliance, response in partition_types\n", | |
"]\n", | |
"\n", | |
"# Notice that we put a lower bound of 0 on each\n", | |
"# p(partition) value.\n", | |
"# This is because probability values cannot be negative.\n", | |
"q = {\n", | |
" partition: pulp.LpVariable(partition, lowBound=0)\n", | |
" for partition in partition_names\n", | |
"}" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We can add constraints on these variables as well, by simply adding them to the problem we've framed!\n", | |
"\n", | |
"An obvious constraint: the probabilities of being in each partition must sum up to 1:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Since our vars are probabilities\n", | |
"# the sum of them should all be under 1.\n", | |
"\n", | |
"# This '+=' operation is adding this sum constraint\n", | |
"# to the linear programming problem.\n", | |
"\n", | |
"max_problem += sum([v for k, v in q.items()]) == 1" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"The real key constraint to add: the probabilities of being in a paritition are linked directly to the experimental data we observe:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Let's set the relationship between the distribution of U\n", | |
"# and the probabilities we observe. The variable name\n", | |
"# scheme is p_Z_X_Y\n", | |
"\n", | |
"\n", | |
"p_treatment_untreated_bad = (\n", | |
" q[\"never_taker/never_better\"]\n", | |
" + q[\"defier/never_better\"]\n", | |
" + q[\"never_taker/helped\"]\n", | |
" + q[\"defier/helped\"]\n", | |
")\n", | |
"\n", | |
"p_treatment_untreated_good = (\n", | |
" q[\"never_taker/always_better\"]\n", | |
" + q[\"defier/always_better\"]\n", | |
" + q[\"never_taker/hurt\"]\n", | |
" + q[\"defier/hurt\"]\n", | |
")\n", | |
"\n", | |
"p_treatment_treated_bad = (\n", | |
" q[\"always_taker/never_better\"]\n", | |
" + q[\"complier/never_better\"]\n", | |
" + q[\"always_taker/hurt\"]\n", | |
" + q[\"complier/hurt\"]\n", | |
")\n", | |
"\n", | |
"p_treatment_treated_good = (\n", | |
" q[\"always_taker/always_better\"]\n", | |
" + q[\"complier/always_better\"]\n", | |
" + q[\"always_taker/helped\"]\n", | |
" + q[\"complier/helped\"]\n", | |
")\n", | |
"\n", | |
"p_control_untreated_bad = (\n", | |
" q[\"never_taker/never_better\"]\n", | |
" + q[\"complier/never_better\"]\n", | |
" + q[\"never_taker/helped\"]\n", | |
" + q[\"complier/helped\"]\n", | |
")\n", | |
"\n", | |
"p_control_untreated_good = (\n", | |
" q[\"never_taker/always_better\"]\n", | |
" + q[\"complier/never_better\"]\n", | |
" + q[\"never_taker/hurt\"]\n", | |
" + q[\"complier/hurt\"]\n", | |
")\n", | |
"\n", | |
"p_control_treated_bad = (\n", | |
" q[\"always_taker/never_better\"]\n", | |
" + q[\"defier/never_better\"]\n", | |
" + q[\"always_taker/hurt\"]\n", | |
" + q[\"defier/hurt\"]\n", | |
")\n", | |
"\n", | |
"p_control_treated_good = (\n", | |
" q[\"always_taker/always_better\"]\n", | |
" + q[\"defier/always_better\"]\n", | |
" + q[\"always_taker/helped\"]\n", | |
" + q[\"defier/helped\"]\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# I now apply these linear relationships between U\n", | |
"# and P(X,Y|Z) as constraints on the LP problem.\n", | |
"\n", | |
"# These '+=' operations are new constraints I am \n", | |
"# adding one-by-one to the 'max_problem' linear programming \n", | |
"# problem.\n", | |
"\n", | |
"max_problem += (\n", | |
" p_treatment_untreated_bad\n", | |
" == p_states.loc[\"treatment/untreated/bad\"]\n", | |
")\n", | |
"max_problem += (\n", | |
" p_treatment_treated_good\n", | |
" == p_states.loc[\"treatment/treated/good\"]\n", | |
")\n", | |
"max_problem += (\n", | |
" p_treatment_treated_bad\n", | |
" == p_states.loc[\"treatment/treated/bad\"]\n", | |
")\n", | |
"\n", | |
"max_problem += (\n", | |
" p_control_untreated_bad\n", | |
" == p_states.loc[\"control/untreated/bad\"]\n", | |
")\n", | |
"max_problem += (\n", | |
" p_control_treated_good\n", | |
" == p_states.loc[\"control/treated/good\"]\n", | |
")\n", | |
"max_problem += (\n", | |
" p_control_treated_bad\n", | |
" == p_states.loc[\"control/treated/bad\"]\n", | |
")\n", | |
"\n", | |
"\n", | |
"# I leave some constraints out because it *over*constrains\n", | |
"# the problem and makes it impossible to solve\n", | |
"# It turns out that the other constraints actually imply\n", | |
"# these two since they are complimentary probabilities,\n", | |
"# so we can just leave them commented out\n", | |
"\n", | |
"# max_problem += p_control_untreated_good == p_states.loc['control/untreated/good']\n", | |
"# max_problem += p_treatment_untreated_good == p_states.loc['treatment/untreated/good']" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"With constraints set, all that remains to do set the objective function, our ATE.\n", | |
"\n", | |
"$$ATE = P(helped) - P(hurt)$$\n", | |
"\n", | |
"where $P(helped)$ is the probability of being in a partition with a 'helped' response type and $P(hurt)$ is the probability of being in a partition with a 'hurt' response type.\n", | |
"\n", | |
"Then we can just hit `solve()`." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"max_problem += (\n", | |
" q[\"complier/helped\"]\n", | |
" + q[\"defier/helped\"]\n", | |
" + q[\"always_taker/helped\"]\n", | |
" + q[\"never_taker/helped\"]\n", | |
" - q[\"complier/hurt\"]\n", | |
" - q[\"defier/hurt\"]\n", | |
" - q[\"always_taker/hurt\"]\n", | |
" - q[\"never_taker/hurt\"]\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'Optimal'" | |
] | |
}, | |
"execution_count": 24, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"pulp.LpStatus[max_problem.solve()]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"The problem's been optimized, and now we can see what U looks like in the best case scenario." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>p(U=u)</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>always_taker/always_better</th>\n", | |
" <td>0.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>always_taker/helped</th>\n", | |
" <td>0.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>always_taker/hurt</th>\n", | |
" <td>0.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>always_taker/never_better</th>\n", | |
" <td>0.014045</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>complier/always_better</th>\n", | |
" <td>0.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>complier/helped</th>\n", | |
" <td>0.141542</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>complier/hurt</th>\n", | |
" <td>0.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>complier/never_better</th>\n", | |
" <td>0.153923</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>defier/always_better</th>\n", | |
" <td>0.317283</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>defier/helped</th>\n", | |
" <td>0.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>defier/hurt</th>\n", | |
" <td>0.309632</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>defier/never_better</th>\n", | |
" <td>0.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>never_taker/always_better</th>\n", | |
" <td>0.043756</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>never_taker/helped</th>\n", | |
" <td>0.019820</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>never_taker/hurt</th>\n", | |
" <td>0.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>never_taker/never_better</th>\n", | |
" <td>0.000000</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" p(U=u)\n", | |
"always_taker/always_better 0.000000\n", | |
"always_taker/helped 0.000000\n", | |
"always_taker/hurt 0.000000\n", | |
"always_taker/never_better 0.014045\n", | |
"complier/always_better 0.000000\n", | |
"complier/helped 0.141542\n", | |
"complier/hurt 0.000000\n", | |
"complier/never_better 0.153923\n", | |
"defier/always_better 0.317283\n", | |
"defier/helped 0.000000\n", | |
"defier/hurt 0.309632\n", | |
"defier/never_better 0.000000\n", | |
"never_taker/always_better 0.043756\n", | |
"never_taker/helped 0.019820\n", | |
"never_taker/hurt 0.000000\n", | |
"never_taker/never_better 0.000000" | |
] | |
}, | |
"execution_count": 25, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"q_max = {\n", | |
" partition: pulp.value(partition_p)\n", | |
" for partition, partition_p in q.items()\n", | |
"}\n", | |
"\n", | |
"# display\n", | |
"\n", | |
"pd.DataFrame(q_max, index=[\"p(U=u)\"]).T" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Given these we can calculate our best-case ATE." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Best-Case ATE: -0.1483\n" | |
] | |
} | |
], | |
"source": [ | |
"ate = (\n", | |
" lambda q: q[\"complier/helped\"]\n", | |
" + q[\"defier/helped\"]\n", | |
" + q[\"always_taker/helped\"]\n", | |
" + q[\"never_taker/helped\"]\n", | |
" - q[\"complier/hurt\"]\n", | |
" - q[\"defier/hurt\"]\n", | |
" - q[\"always_taker/hurt\"]\n", | |
" - q[\"never_taker/hurt\"]\n", | |
")\n", | |
"\n", | |
"best_case_ate = ate(q_max)\n", | |
"\n", | |
"# display\n", | |
"print(\"Best-Case ATE: %6.4f\" % best_case_ate)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
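The hand-written sums that link q to P(X, Y | Z) can also be derived mechanically. What follows is a minimal, solver-free sketch rather than the notebook's own code: the `COMPLIANCE` and `RESPONSE` lookup tables and the helper names `partitions_for`, `implied_p_states`, and `ate` are assumptions introduced here, based on the usual reading of the partition names. It rebuilds the partition list for each observed state, then plugs in the optimal q values printed in the notebook output to check that they reproduce the observed table to within rounding.

```python
from itertools import product

# Assumed reading of the compliance types:
# assignment Z -> treatment actually taken X.
COMPLIANCE = {
    "never_taker": {"control": "untreated", "treatment": "untreated"},
    "always_taker": {"control": "treated", "treatment": "treated"},
    "complier": {"control": "untreated", "treatment": "treated"},
    "defier": {"control": "treated", "treatment": "untreated"},
}

# Assumed reading of the response types:
# treatment taken X -> outcome Y.
RESPONSE = {
    "never_better": {"untreated": "bad", "treated": "bad"},
    "always_better": {"untreated": "good", "treated": "good"},
    "helped": {"untreated": "bad", "treated": "good"},
    "hurt": {"untreated": "good", "treated": "bad"},
}


def partitions_for(z, x, y):
    """Names of the partitions that produce outcome (x, y) under assignment z."""
    return sorted(
        f"{c}/{r}"
        for c, r in product(COMPLIANCE, RESPONSE)
        if COMPLIANCE[c][z] == x and RESPONSE[r][x] == y
    )


def implied_p_states(q):
    """P(X, Y | Z) implied by a partition distribution q (missing keys count as 0)."""
    states = product(
        ("control", "treatment"), ("untreated", "treated"), ("bad", "good")
    )
    return {
        f"{z}/{x}/{y}": sum(q.get(name, 0.0) for name in partitions_for(z, x, y))
        for z, x, y in states
    }


def ate(q):
    """ATE = P(helped) - P(hurt), summed over all compliance types."""
    helped = sum(p for name, p in q.items() if name.endswith("/helped"))
    hurt = sum(p for name, p in q.items() if name.endswith("/hurt"))
    return helped - hurt


# Nonzero entries of the optimal distribution, copied from the notebook output.
Q_MAX = {
    "always_taker/never_better": 0.014045,
    "complier/helped": 0.141542,
    "complier/never_better": 0.153923,
    "defier/always_better": 0.317283,
    "defier/hurt": 0.309632,
    "never_taker/always_better": 0.043756,
    "never_taker/helped": 0.019820,
}

# Observed P(X, Y | Z), copied from the notebook's p_states table.
OBSERVED = {
    "control/untreated/bad": 0.315285,
    "control/untreated/good": 0.043756,
    "control/treated/bad": 0.323676,
    "control/treated/good": 0.317283,
    "treatment/untreated/bad": 0.019820,
    "treatment/untreated/good": 0.670671,
    "treatment/treated/bad": 0.167968,
    "treatment/treated/good": 0.141542,
}

implied = implied_p_states(Q_MAX)
for state, observed_p in OBSERVED.items():
    # The displayed values are rounded to six decimals, so allow a small tolerance.
    assert abs(implied[state] - observed_p) < 1e-5, state

print("Best-Case ATE: %6.4f" % ate(Q_MAX))
```

Deriving the sums from two small lookup tables makes it harder to mistype a partition name in one of the eight hand-written expressions, and the final loop doubles as a sanity check on any solution the solver returns.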