-- Numerically stable parallel safe standard deviation in PostgreSQL
-- Source: GitHub gist Nipsuli/7a0bc73d31676ec8344c4b5f8c8f770d, created May 22, 2022 06:47.
-- Note: the gist page warns that this file may contain bidirectional Unicode text;
-- review it in an editor that reveals hidden Unicode characters.
-- State transition function for the weighted_std aggregate.
--
-- State layout (numeric[3]):
--   state[1] = S,  weighted sum of squared deviations from the running mean
--   state[2] = mu, running weighted mean
--   state[3] = W,  sum of the weights seen so far
-- Rows where val or weight is NULL leave the state unchanged. If the total
-- weight becomes exactly 0, the state is reset to zeros.
--
-- Uses West's incremental update, equations (47), (53) and (68) in
-- https://fanf2.user.srcf.net/hermes/doc/antiforgery/stats.pdf
--
-- Why divide last: numeric multiplication is exact, so a product's scale is
-- the sum of its operands' scales. Writing the mean update as
-- mu + (weight / W) * (val - mu) adds the quotient's ~16-20 digit scale to
-- the mean on every row, so the state (and each step's cost) grows without
-- bound. Dividing last keeps the scale governed by numeric division's own
-- scale rule. Fractional weights would still add their scale each row, so
-- the mean is additionally rounded once its scale exceeds the inputs' scale
-- plus guard_digits.
CREATE OR REPLACE FUNCTION _wstd_state(state numeric[3], val numeric, weight numeric)
RETURNS numeric[3] AS $$
DECLARE
    s_n_1  CONSTANT numeric NOT NULL := state[1];
    mu_n_1 CONSTANT numeric NOT NULL := state[2];
    w_n_1  CONSTANT numeric NOT NULL := state[3];
    -- Extra decimal digits kept beyond the inputs' own scale when capping the mean.
    guard_digits CONSTANT integer := 32;
    s_n       numeric;
    mu_n      numeric;
    w_n       numeric;
    dev       numeric;
    max_scale integer;
BEGIN
    IF val IS NULL OR weight IS NULL THEN
        RETURN state;
    END IF;

    w_n := w_n_1 + weight;  -- (47)
    IF w_n = 0 THEN
        -- No weight --> cannot have any values; pass zeros on.
        RETURN ARRAY[0.0, 0.0, 0.0];
    END IF;

    dev := val - mu_n_1;
    mu_n := mu_n_1 + weight * dev / w_n;  -- (53), division performed last

    -- Bound the scale growth caused by fractional weights. val is exactly
    -- representable at max_scale, so rounding cannot move mu_n past val and
    -- the (68) increment below stays non-negative for non-negative weights.
    max_scale := GREATEST(scale(val), scale(weight)) + guard_digits;
    IF scale(mu_n) > max_scale THEN
        mu_n := round(mu_n, max_scale);
    END IF;

    s_n := s_n_1 + weight * dev * (val - mu_n);  -- (68)
    RETURN ARRAY[s_n, mu_n, w_n];
END;
$$ LANGUAGE plpgsql IMMUTABLE PARALLEL SAFE;
-- Final function for the weighted_std aggregate.
-- Turns the accumulated state [S, mu, W] into sqrt(S / W), the weighted
-- standard deviation normalized by the total weight. Returns NULL when no
-- weight was accumulated (W = 0). Raises an error if state[1] or state[3]
-- is NULL (enforced by the NOT NULL constants).
CREATE OR REPLACE FUNCTION _wstd_final(state numeric[3])
RETURNS numeric AS $$
DECLARE
    sum_sq_dev   CONSTANT numeric NOT NULL := state[1];
    total_weight CONSTANT numeric NOT NULL := state[3];
BEGIN
    IF total_weight <> 0 THEN
        RETURN sqrt(sum_sq_dev / total_weight);
    END IF;
    -- Nothing contributed any weight, so the deviation is undefined.
    RETURN NULL;
END;
$$ LANGUAGE plpgsql IMMUTABLE;
-- Combine function for the weighted_std aggregate: merges two partial states
-- [S, mu, W] (S = weighted sum of squared deviations, mu = weighted mean,
-- W = total weight) produced by parallel workers into one state.
--
-- Source:
--   Schubert, Erich, and Michael Gertz.
--   "Numerically stable parallel computation of (co-) variance."
--   Proceedings of the 30th International Conference on Scientific and
--   Statistical Database Management. ACM, 2018. DOI:10.1145/3221269.3223036
--   Simpler presentation:
--   https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
--
-- The squared mean difference is computed as delta * delta. numeric
-- multiplication is exact, whereas power(delta, 2) rounds its result to a
-- limited scale.
CREATE OR REPLACE FUNCTION _wstd_combine(state1 numeric[3], state2 numeric[3])
RETURNS numeric[3] AS $$
DECLARE
    s_a  CONSTANT numeric NOT NULL := state1[1];
    mu_a CONSTANT numeric NOT NULL := state1[2];
    w_a  CONSTANT numeric NOT NULL := state1[3];
    s_b  CONSTANT numeric NOT NULL := state2[1];
    mu_b CONSTANT numeric NOT NULL := state2[2];
    w_b  CONSTANT numeric NOT NULL := state2[3];
    w_ab  numeric;
    delta numeric;
BEGIN
    w_ab := w_a + w_b;
    IF w_ab = 0 THEN
        -- Both branches were empty (still the initial zeros array);
        -- continue with the zeros array.
        RETURN ARRAY[0.0, 0.0, 0.0];
    END IF;

    delta := mu_b - mu_a;
    RETURN ARRAY[
        s_a + s_b + delta * delta * w_a * w_b / w_ab,  -- merged S
        mu_a + delta * w_b / w_ab,                     -- merged mean
        w_ab                                           -- merged weight
    ];
END;
$$ LANGUAGE plpgsql IMMUTABLE PARALLEL SAFE;
-- weighted_std(val, weight): single-pass, parallelizable weighted standard
-- deviation (sqrt of the weighted sum of squared deviations divided by the
-- total weight), built from the _wstd_state / _wstd_combine / _wstd_final
-- support functions. The state is [S, mu, W], starting from zeros.
CREATE OR REPLACE AGGREGATE weighted_std(val numeric, weight numeric) (
    STYPE       = numeric[3],
    INITCOND    = '{0.0, 0.0, 0.0}',
    SFUNC       = _wstd_state,
    COMBINEFUNC = _wstd_combine,
    FINALFUNC   = _wstd_final,
    PARALLEL    = SAFE
);

COMMENT ON AGGREGATE weighted_std(numeric, numeric) IS 'This is not 100% accurrate method, but the difference is super marginal. This approximation is faster to calculate than the "real weighted standard deviation" in sql context as it requires only single pass over data and is parallelizable. Source: Schubert, Erich, and Michael Gertz. "Numerically stable parallel computation of (co-) variance." Proceedings of the 30th International Conference on Scientific and Statistical Database Management. ACM, 2018. (Won the SSDBM best-paper award.) DOI:10.1145/3221269.3223036';
-- (GitHub page footer: sign in to GitHub to comment on this gist.)