Skip to content

Instantly share code, notes, and snippets.

@Nipsuli
Created May 22, 2022 06:47
Show Gist options
  • Save Nipsuli/7a0bc73d31676ec8344c4b5f8c8f770d to your computer and use it in GitHub Desktop.
Save Nipsuli/7a0bc73d31676ec8344c4b5f8c8f770d to your computer and use it in GitHub Desktop.
Numerically stable parallel safe standard deviation in PostgreSQL
/*
 * _wstd_state: transition function of the weighted_std aggregate.
 *
 * Parameters:
 *   state  - running triple {S, mu, W}: S is the quantity that _wstd_final
 *            divides by W before taking the square root, mu is the running
 *            weighted mean, W is the total weight accumulated so far.
 *   val    - the next observation.
 *   weight - the weight of that observation.
 *
 * Returns the updated {S, mu, W}.
 *   - A row whose val or weight is NULL leaves the state unchanged.
 *   - If the accumulated weight is exactly zero the state is reset to zeros.
 * Raises an error if any of the three state elements is NULL (enforced by
 * the NOT NULL declarations below).
 *
 * The update equations, and their numbers (47), (53), (68), come from
 * https://fanf2.user.srcf.net/hermes/doc/antiforgery/stats.pdf
 *
 * Marked PARALLEL SAFE: the body is pure arithmetic on its arguments, and
 * the aggregate that uses it is itself declared PARALLEL = SAFE.
 *
 * NOTE(review): numeric multiplication keeps the full product scale, so the
 * number of digits carried in mu and S may grow as rows accumulate. Worth
 * profiling on large inputs before relying on this in hot paths.
 */
CREATE OR REPLACE FUNCTION _wstd_state(state numeric[3], val numeric, weight numeric)
RETURNS numeric[3] AS $$
DECLARE
    s_n_1  CONSTANT numeric NOT NULL := state[1];
    mu_n_1 CONSTANT numeric NOT NULL := state[2];
    w_n_1  CONSTANT numeric NOT NULL := state[3];
    s_n    numeric;
    mu_n   numeric;
    w_n    numeric;
BEGIN
    -- Rows with a missing value or weight contribute nothing.
    IF val IS NULL OR weight IS NULL THEN
        RETURN state;
    END IF;

    w_n := w_n_1 + weight; -- (47)

    -- No weight --> cannot have any values; pass zeros on.
    -- This also prevents the division by zero in (53).
    IF w_n = 0 THEN
        RETURN ARRAY[0.0, 0.0, 0.0];
    END IF;

    mu_n := mu_n_1 + (weight / w_n) * (val - mu_n_1);        -- (53)
    s_n  := s_n_1 + weight * (val - mu_n_1) * (val - mu_n);  -- (68)

    RETURN ARRAY[s_n, mu_n, w_n];
END;
$$ LANGUAGE plpgsql IMMUTABLE PARALLEL SAFE;
/*
 * _wstd_final: final function of the weighted_std aggregate.
 *
 * Parameters:
 *   state - the accumulated triple {S, mu, W}; only S (element 1) and
 *           W (element 3) are read here.
 *
 * Returns sqrt(S / W), or NULL when W is zero (no weight was accumulated).
 * Raises an error if state[1] or state[3] is NULL (NOT NULL declarations).
 *
 * Marked PARALLEL SAFE: the body is pure arithmetic on its argument, and the
 * aggregate that uses it is itself declared PARALLEL = SAFE.
 */
CREATE OR REPLACE FUNCTION _wstd_final(state numeric[3])
RETURNS numeric AS $$
DECLARE
    s_n CONSTANT numeric NOT NULL := state[1];
    w_n CONSTANT numeric NOT NULL := state[3];
BEGIN
    -- NULLIF maps a zero total weight to NULL; the division and sqrt then
    -- propagate NULL instead of raising a division-by-zero error.
    RETURN sqrt(s_n / NULLIF(w_n, 0));
END;
$$ LANGUAGE plpgsql IMMUTABLE PARALLEL SAFE;
/*
 * _wstd_combine: combine function of the weighted_std aggregate. Merges two
 * partial states produced by independent (e.g. parallel) runs of _wstd_state.
 *
 * Parameters:
 *   state1, state2 - partial triples {S, mu, W}, as produced by _wstd_state.
 *
 * Returns the merged {S, mu, W}. If the combined weight is zero (e.g. both
 * branches were empty and still hold the initial zeros array) the zeros
 * array is returned.
 * Raises an error if any element of either state is NULL (NOT NULL
 * declarations).
 *
 * Source:
 *   Schubert, Erich, and Michael Gertz.
 *   "Numerically stable parallel computation of (co-) variance."
 *   Proceedings of the 30th International Conference on Scientific and
 *   Statistical Database Management. ACM, 2018.
 *   DOI:10.1145/3221269.3223036
 * For a simpler presentation see Wikipedia:
 *   https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
 *
 * Marked PARALLEL SAFE: the body is pure arithmetic on its arguments, and
 * the aggregate that uses it is itself declared PARALLEL = SAFE.
 */
CREATE OR REPLACE FUNCTION _wstd_combine(state1 numeric[3], state2 numeric[3])
RETURNS numeric[3] AS $$
DECLARE
    s_a   CONSTANT numeric NOT NULL := state1[1];
    mu_a  CONSTANT numeric NOT NULL := state1[2];
    w_a   CONSTANT numeric NOT NULL := state1[3];
    s_b   CONSTANT numeric NOT NULL := state2[1];
    mu_b  CONSTANT numeric NOT NULL := state2[2];
    w_b   CONSTANT numeric NOT NULL := state2[3];
    w_ab  numeric;
    mu_ab numeric;
    s_ab  numeric;
    delta numeric;
BEGIN
    w_ab := w_a + w_b;

    -- Both branches empty: keep the zeros array and avoid dividing by zero.
    IF w_ab = 0 THEN
        RETURN ARRAY[0.0, 0.0, 0.0];
    END IF;

    mu_ab := (mu_a * w_a + mu_b * w_b) / w_ab;
    delta := mu_b - mu_a;
    -- delta * delta rather than power(delta, 2): numeric multiplication keeps
    -- the full product, whereas power() picks its own result scale and may
    -- round the square. It also matches the arithmetic style of _wstd_state.
    s_ab := s_a + s_b + (delta * delta * w_a * w_b) / w_ab;

    RETURN ARRAY[s_ab, mu_ab, w_ab];
END;
$$ LANGUAGE plpgsql IMMUTABLE PARALLEL SAFE;
-- weighted_std(val, weight): aggregate wiring for the _wstd_* support functions.
-- The state is a three-element numeric array that starts as zeros; rows are
-- folded in by _wstd_state, partial states are merged by _wstd_combine, and
-- _wstd_final turns the finished state into the result.
CREATE OR REPLACE AGGREGATE weighted_std(val numeric, weight numeric) (
    sfunc       = _wstd_state,        -- per-row transition
    stype       = numeric[3],         -- state array
    initcond    = '{0.0, 0.0, 0.0}',  -- empty state
    finalfunc   = _wstd_final,        -- state -> result
    combinefunc = _wstd_combine,      -- merges partial states
    parallel    = safe                -- allow use in parallel plans
);
-- Catalog description of the aggregate (shown by \da+ in psql and returned by obj_description()).
COMMENT ON AGGREGATE weighted_std(numeric, numeric) IS 'This is not a 100% accurate method, but the difference is super marginal. This approximation is faster to calculate than the "real weighted standard deviation" in sql context as it requires only single pass over data and is parallelizable. Source: Schubert, Erich, and Michael Gertz. "Numerically stable parallel computation of (co-) variance." Proceedings of the 30th International Conference on Scientific and Statistical Database Management. ACM, 2018. (Won the SSDBM best-paper award.) DOI:10.1145/3221269.3223036';
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment