@mattsgithub
Last active December 4, 2019 18:58
Prepares Pandas DataFrame for Counterfactual Predictions
import pandas as pd


def get_pandas_df_for_counterfactual_prediction(df,
                                                df_treatments,
                                                id_column='customer_id',
                                                suffix='_assigned'):
    """Returns a pandas dataframe prepared for counterfactual predictions.

    Args:
        df (pd.DataFrame):
            The dataframe to prepare counterfactuals for.
        df_treatments (pd.DataFrame):
            A dataframe of treatments. In experiments, we might
            assign several different variable values. This treatment
            dataframe contains all possible assignments.
        id_column (str):
            The column name in df that represents the unit we are doing
            analysis on.
        suffix (str):
            Treatment columns in df will be renamed with this suffix. It's
            useful in retaining the original treatment columns for further
            analysis.

    Examples:
        Suppose df is:
            customer_id  t1  t2
            1            1   0.1
            2            3   0.2

        Suppose df_treatments is:
            t1  t2
            1   0.1
            1   0.2
            3   0.1
            3   0.2

        >>> get_pandas_df_for_counterfactual_prediction(df, df_treatments)

        Returns:
            customer_id  t1_assigned  t2_assigned  t1  t2
            1            1            0.1          1   0.1
            1            1            0.1          1   0.2
            1            1            0.1          3   0.1
            1            1            0.1          3   0.2
            2            3            0.2          1   0.1
            2            3            0.2          1   0.2
            2            3            0.2          3   0.1
            2            3            0.2          3   0.2
    """
    intersecting_cols = df_treatments.columns.intersection(df.columns)
    cols_not_found_in_df = df_treatments.columns.difference(intersecting_cols)
    if len(cols_not_found_in_df) > 0:
        raise ValueError('The columns {} in df_treatments '
                         'are not found in df'.format(list(cols_not_found_in_df)))

    n_rows = df.shape[0]
    n_treatments = df_treatments.shape[0]
    df = df.rename(columns={c: c + suffix for c in df_treatments.columns})

    # Generate a row for each counterfactual prediction.
    # DataFrame.append was removed in pandas 2.0, so use pd.concat instead.
    df = pd.concat([df] * n_treatments)
    # Group the copies of each unit together. The pairing below assumes
    # id_column values are unique in the original df.
    df = df.sort_values(id_column)
    df = df.reset_index(drop=True)

    # Tile the treatments once per original row so they line up with df
    df_treatments = pd.concat([df_treatments] * n_rows)
    df_treatments = df_treatments.reset_index(drop=True)

    df = pd.concat([df, df_treatments], axis=1)
    return df
@mattsgithub
Author

After training a heterogeneous treatment effects model, you will want to prepare a dataframe for counterfactual predictions by pairing every unit with every possible combination of treatments.
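For illustration, the same enumeration can be sketched as a cross join. This is only a sketch under assumptions: it relies on `DataFrame.merge(how="cross")`, which needs pandas 1.2 or later, and the helper name `build_counterfactual_frame` is hypothetical and not part of the gist. The sample data mirrors the docstring example.

```python
import pandas as pd


def build_counterfactual_frame(df, df_treatments, suffix="_assigned"):
    """Pair every row of df with every row of df_treatments.

    Hypothetical helper, shown only to illustrate the enumeration.
    """
    missing = df_treatments.columns.difference(df.columns)
    if len(missing) > 0:
        raise ValueError(
            "The columns {} in df_treatments are not found in df".format(list(missing))
        )
    # Keep the originally assigned treatments under a suffixed name.
    renamed = df.rename(columns={c: c + suffix for c in df_treatments.columns})
    # A cross join produces the Cartesian product of the two frames.
    return renamed.merge(df_treatments, how="cross")


# Sample data taken from the docstring example.
df = pd.DataFrame({"customer_id": [1, 2], "t1": [1, 3], "t2": [0.1, 0.2]})
df_treatments = pd.DataFrame({"t1": [1, 1, 3, 3], "t2": [0.1, 0.2, 0.1, 0.2]})

counterfactuals = build_counterfactual_frame(df, df_treatments)
print(counterfactuals)
```

Unlike the sort-based approach, a cross join does not need an id column and does not reorder the input rows, so it may be the simpler option when a recent pandas version is available.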
