Skip to content

Instantly share code, notes, and snippets.

@chezou
Last active March 15, 2019 06:26
Show Gist options
  • Save chezou/b348b951e25442cec4cdf0390e7bb2a6 to your computer and use it in GitHub Desktop.
Save chezou/b348b951e25442cec4cdf0390e7bb2a6 to your computer and use it in GitHub Desktop.
Calculate multiclass micro/macro precision and recall in SQL
-- Multiclass micro/macro precision and recall.
--
-- Inputs (defined outside this file):
--   prediction_rf : one row per scored example; reads rowid and label
--                   (the predicted class).
--   test          : ground truth; reads rowid and survived (the true class).
--
-- Output: exactly one row with the columns
--   recall_micro, precision_micro, recall_macro, precision_macro (double).
--
-- Conventions:
--   * A class is any value that occurs as a predicted OR a true label, so a
--     class that is never predicted (or never occurs in the ground truth)
--     still gets its own row and stays aligned with its own counts.
--   * A per-class ratio whose denominator is 0 contributes 0 to the macro
--     average (the same default scikit-learn applies to ill-defined scores).
--   * With no joined rows at all, every metric is NULL.
with agg as (
    -- Predicted label next to the true label for every scored example.
    select
        p.label
        , t.survived
    from
        prediction_rf as p
    inner join
        test as t on p.rowid = t.rowid
),
class_counts as (
    -- One row per class:
    --   pred_sum      = number of examples predicted as this class
    --   true_sum      = number of examples whose true label is this class
    --   true_positive = number of examples predicted as this class correctly
    -- Each example is emitted twice (once under its predicted class, once
    -- under its true class) so that a single group by yields all three
    -- counts keyed on the same class value.
    select
        class_label
        , sum(is_pred) as pred_sum
        , sum(is_true) as true_sum
        , sum(is_hit) as true_positive
    from
        (
            select
                label as class_label
                , 1 as is_pred
                , 0 as is_true
                , case when label = survived then 1 else 0 end as is_hit
            from
                agg
            union all
            select
                survived as class_label
                , 0 as is_pred
                , 1 as is_true
                , 0 as is_hit
            from
                agg
        ) as class_events
    group by
        class_label
)
select
    -- Micro: pool the counts over all classes, then take the ratio.
    cast(sum(true_positive) as double) / nullif(sum(true_sum), 0) as recall_micro
    , cast(sum(true_positive) as double) / nullif(sum(pred_sum), 0) as precision_micro
    -- Macro: take the ratio per class, then average the ratios unweighted.
    -- nullif avoids dividing by zero; coalesce turns the undefined ratio into 0.
    , avg(coalesce(cast(true_positive as double) / nullif(true_sum, 0), cast(0 as double))) as recall_macro
    , avg(coalesce(cast(true_positive as double) / nullif(pred_sum, 0), cast(0 as double))) as precision_macro
from
    class_counts
;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment