Last active
August 25, 2022 09:53
-
-
Save niczky12/587e601874c88d63cc6e808dae6e5a34 to your computer and use it in GitHub Desktop.
Permutation test for BigQuery tables
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*
Written by Bence Komarniczky
https://github.com/niczky12
https://niczky12.medium.com
Under MIT licence, free to use but no guarantees.
*/
-- One-sided permutation test on the difference of group means.
--
-- Parameters:
--   table_name  table to read from; it is wrapped in backticks in the query.
--   col         column (or expression) holding the numeric measurement.
--   flag_col    column (or expression) holding the group flag; the query
--               compares it against 1 (treatment) and 0 (control).
--   num_iter    number of random re-labellings to generate (must be >= 1).
--   p (OUT)     share of re-labellings whose mean difference
--               (group 1 - group 0) is >= the observed difference.
--
-- Notes:
--   * Each row is assigned to group 0 or 1 independently from the parity of a
--     hash of (row id, permutation id), so the original group sizes are not
--     preserved in the re-labelled samples.
--   * table_name, col and flag_col are interpolated into the SQL text because
--     identifiers cannot be bound as query parameters. Only call this
--     procedure with trusted values.
CREATE OR REPLACE PROCEDURE ds.permutation_test(
  table_name STRING
  ,col STRING
  ,flag_col STRING
  ,num_iter INT64
  ,OUT p FLOAT64)
BEGIN
  DECLARE sql_string STRING;

  -- CONCAT returns NULL if any argument is NULL, which would make
  -- EXECUTE IMMEDIATE fail with an unhelpful message; fail early instead.
  IF table_name IS NULL OR col IS NULL OR flag_col IS NULL THEN
    RAISE USING MESSAGE = 'permutation_test: table_name, col and flag_col must not be NULL';
  END IF;
  IF num_iter IS NULL OR num_iter < 1 THEN
    RAISE USING MESSAGE = 'permutation_test: num_iter must be a positive integer';
  END IF;

  SET sql_string = CONCAT(
    "WITH original_table AS (",
    -- rename columns and add row id (row_number() has no ordering, so ids are arbitrary)
    "select ",
    col, " as measurement",
    ",cast(row_number() over() as string) as row_id,",
    flag_col, " as treatment ",
    "from `", table_name, "`",
    "), permutations as (",
    -- exactly num_iter permutation ids (1..num_iter); num_iter is bound as a query parameter
    "select cast(permutation_id as string) as permutation_id",
    " from unnest(generate_array(1, @num_iter)) as permutation_id",
    "), experiments as (",
    -- cross join datasets and assign observations to the 2 groups randomly;
    -- the '_' separator keeps (row 1, permutation 11) distinct from (row 11, permutation 1),
    -- and abs(mod(x, 2)) avoids the overflow abs() raises on the smallest INT64
    "select permutation_id, measurement,",
    " abs(mod(farm_fingerprint(concat(row_id, '_', permutation_id)), 2)) as treatment",
    " from original_table cross join permutations",
    "), means as (",
    -- measure mean differences for permutations
    "select permutation_id",
    ",avg(case when treatment = 1 then measurement else null end)",
    " - avg(case when treatment = 0 then measurement else null end) as mean_diff",
    " from experiments group by permutation_id",
    "), observed as (",
    -- record original observed difference
    "select avg(case when treatment = 1 then measurement else null end)",
    " - avg(case when treatment = 0 then measurement else null end) as observed_diff",
    " from original_table",
    "), comparisons as (",
    -- compare permuted results to original results (one-sided: permuted >= observed)
    "select permutation_id, mean_diff, observed_diff,",
    " case when mean_diff >= observed_diff then 1 else 0 end as is_more_extreme",
    " from means cross join observed)",
    -- calculate ratio of more extreme cases, this is our p-value
    "select avg(is_more_extreme) as p_value from comparisons"
  );

  EXECUTE IMMEDIATE sql_string INTO p USING num_iter AS num_iter;
END;
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment