Calculate multiclass micro/macro precision and recall in SQL
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Multiclass micro/macro precision and recall for the predictions in
-- prediction_rf, scored against test.survived (Presto/Trino dialect).
--
-- For each class c:
--   pred_sum(c)      = rows predicted as c
--   true_sum(c)      = rows whose true label is c
--   true_positive(c) = rows predicted as c whose true label is also c
--
-- Micro scores pool the counts over all classes. Every scored row adds exactly
-- one to pred_sum and one to true_sum, so both micro scores equal accuracy.
-- Macro scores are the unweighted mean of the per-class ratios.
-- A class with a zero denominator (never predicted, or never present in the
-- ground truth) contributes 0 to the macro average, like scikit-learn's
-- zero_division default.
-- NOTE(review): NULL labels, if any exist, form their own class here; confirm
-- whether they should be filtered out upstream.
with agg as (
    -- one row per scored test example: predicted label and true label
    select
        p.label
        , t.survived
    from
        prediction_rf p
    inner join
        test t on (p.rowid = t.rowid)
),
label_events as (
    -- each example counts once toward its predicted class ...
    select
        label as label
        , 1 as is_pred
        , 0 as is_true
        , case when label = survived then 1 else 0 end as is_tp
    from
        agg
    union all
    -- ... and once toward its true class
    select
        survived as label
        , 0 as is_pred
        , 1 as is_true
        , 0 as is_tp
    from
        agg
),
per_label as (
    -- One row per class that appears in either the predictions or the ground
    -- truth. Keeping all three counts on the same row keyed by label avoids
    -- any positional pairing of classes.
    select
        label
        , sum(is_pred) as pred_sum
        , sum(is_true) as true_sum
        , sum(is_tp) as true_positive
    from
        label_events
    group by
        label
)
-- No GROUP BY: aggregate over all classes into a single result row.
-- If there is no input, the sums are NULL, so the scores are NULL rather than
-- a division error.
select
    cast(sum(true_positive) as double) / sum(true_sum) as recall_micro
    , cast(sum(true_positive) as double) / sum(pred_sum) as precision_micro
    , avg(coalesce(cast(true_positive as double) / nullif(true_sum, 0), 0.0)) as recall_macro
    , avg(coalesce(cast(true_positive as double) / nullif(pred_sum, 0), 0.0)) as precision_macro
from
    per_label
;
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment