Created
December 8, 2023 21:09
-
-
Save anatolec/8ef6b08b97cc00a8ca2a4625357cbf98 to your computer and use it in GitHub Desktop.
A BigQuery UDF to compute ROC AUC given 2 arrays of scores and labels
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
CREATE OR REPLACE FUNCTION `<project-id>.<dataset-id>.roc_auc`(scores ARRAY<FLOAT64>, labels ARRAY<BOOL>) | |
RETURNS FLOAT64 AS ( | |
( | |
with input_data as ( | |
select scores.score, labels.label | |
from( | |
SELECT *, row_number() over () as row_num | |
FROM UNNEST(scores) as score) scores | |
join( | |
SELECT *, row_number() over () as row_num | |
FROM UNNEST(labels) as label) labels | |
on scores.row_num=labels.row_num | |
), | |
distinct_predictions as ( | |
select distinct score as threshold, | |
from input_data | |
), | |
roc_curve as ( | |
select | |
threshold, | |
safe_divide(countif(score >= threshold and not label), countif(not label)) as fpr, | |
safe_divide(countif(score >= threshold and label), countif(label)) as tpr | |
from input_data | |
cross join distinct_predictions | |
group by threshold | |
order by fpr | |
), | |
auc as ( | |
select | |
fpr, | |
tpr, | |
lag(fpr) over (order by fpr) as next_fpr, | |
lag(tpr) over (order by tpr) as next_tpr, | |
(fpr - lag(fpr) over (order by fpr)) * (tpr + lag(tpr) over (order by tpr)) / 2 as roc_auc_contrib | |
from ( | |
select * | |
from roc_curve | |
union all | |
select max(threshold) + 1, 0, 0 from roc_curve) | |
order by fpr | |
) | |
select sum(roc_auc_contrib) as roc_auc from auc) | |
); | |
-- Sample usage : | |
WITH predictions AS ( | |
SELECT | |
type, | |
score, | |
label, | |
FROM transactions | |
) | |
SELECT | |
type, | |
<project-id>.roc_auc(array_agg(score), array_agg(label)) | |
FROM predictions | |
GROUP BY type |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment