Skip to content

Instantly share code, notes, and snippets.

@niczky12
Last active August 25, 2022 09:53
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save niczky12/587e601874c88d63cc6e808dae6e5a34 to your computer and use it in GitHub Desktop.
Save niczky12/587e601874c88d63cc6e808dae6e5a34 to your computer and use it in GitHub Desktop.
Permutation test for BigQuery tables
/*
Written by Bence Komarniczky
https://github.com/niczky12
https://niczky12.medium.com
Under MIT licence, free to use but no guarantees.
*/
-- Estimates a one-sided p-value for the difference in group means
-- (mean of rows flagged 1 minus mean of rows flagged 0) by random relabelling.
--
-- Parameters:
--   table_name  Table to read. It is wrapped in backticks inside the query,
--               so pass it WITHOUT surrounding backticks.
--   col         Column (or expression) holding the numeric measurement.
--   flag_col    Column (or expression) holding the group flag. The query
--               compares it against 1 and 0, so it must evaluate to those values.
--   num_iter    Number of random relabellings to generate; must be >= 1.
--   p (OUT)     Share of relabellings whose mean difference is >= the observed
--               mean difference (one-sided, "treatment larger").
--
-- Notes:
--   * Each row is assigned to group 0 or 1 independently by hashing
--     (row id, permutation id), so group sizes are NOT held fixed.
--   * Row ids come from ROW_NUMBER() with no ORDER BY.
--     NOTE(review): results may therefore not be reproducible between runs - confirm.
--   * table_name, col and flag_col are spliced into the SQL text as identifiers
--     (identifiers cannot be bound as query parameters). Only pass trusted values.
CREATE OR REPLACE PROCEDURE ds.permutation_test(
table_name STRING
,col STRING
,flag_col STRING
,num_iter INT64
,OUT p FLOAT64)
BEGIN
  DECLARE sql_string STRING;

  -- Fail early with a clear message; FORMAT() would otherwise yield a NULL
  -- statement and EXECUTE IMMEDIATE would raise an unrelated-looking error.
  IF table_name IS NULL OR col IS NULL OR flag_col IS NULL THEN
    RAISE USING MESSAGE = 'permutation_test: table_name, col and flag_col must not be NULL';
  END IF;

  IF num_iter IS NULL OR num_iter < 1 THEN
    RAISE USING MESSAGE = 'permutation_test: num_iter must be a positive integer';
  END IF;

  -- Only the identifiers are substituted via FORMAT(); num_iter is bound as the
  -- named query parameter @num_iter when the statement is executed.
  SET sql_string = FORMAT("""
WITH original_table AS (
  -- rename columns and add row id
  SELECT
    %s AS measurement,
    CAST(ROW_NUMBER() OVER () AS STRING) AS row_id,
    %s AS treatment
  FROM `%s`
),
permutations AS (
  -- one id per relabelling: 1..@num_iter gives exactly @num_iter of them
  SELECT CAST(permutation_id AS STRING) AS permutation_id
  FROM UNNEST(GENERATE_ARRAY(1, @num_iter)) AS permutation_id
),
experiments AS (
  -- cross join datasets and assign observations to the 2 groups randomly;
  -- the '_' separator keeps (row 1, perm 11) and (row 11, perm 1) distinct,
  -- and MOD is taken before ABS so ABS never sees the minimum INT64 value
  SELECT
    permutations.permutation_id,
    original_table.measurement,
    ABS(MOD(FARM_FINGERPRINT(CONCAT(original_table.row_id, '_', permutations.permutation_id)), 2)) AS treatment
  FROM original_table
  CROSS JOIN permutations
),
means AS (
  -- measure mean differences for permutations
  SELECT
    permutation_id,
    AVG(CASE WHEN treatment = 1 THEN measurement ELSE NULL END)
      - AVG(CASE WHEN treatment = 0 THEN measurement ELSE NULL END) AS mean_diff
  FROM experiments
  GROUP BY permutation_id
),
observed AS (
  -- record original observed difference
  SELECT
    AVG(CASE WHEN treatment = 1 THEN measurement ELSE NULL END)
      - AVG(CASE WHEN treatment = 0 THEN measurement ELSE NULL END) AS observed_diff
  FROM original_table
),
comparisons AS (
  -- compare permuted results to original results (observed has a single row)
  SELECT
    means.permutation_id,
    means.mean_diff,
    observed.observed_diff,
    CASE WHEN means.mean_diff >= observed.observed_diff THEN 1 ELSE 0 END AS is_more_extreme
  FROM means
  CROSS JOIN observed
)
-- calculate ratio of more extreme cases, this is our p-value
SELECT AVG(is_more_extreme) AS p_value
FROM comparisons
""", col, flag_col, table_name);

  EXECUTE IMMEDIATE sql_string INTO p USING num_iter AS num_iter;
END;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment