Numerically stable math
(ns jupiter.test.utils.math
(:require [clojure.test :refer :all]
[jupiter.utils.math :refer :all]))
(defn float=
  "Approximate numeric equality.

  Returns true when |x - y| <= scale * epsilon, where scale is |x|, or 1
  when either x or y is zero (so the comparison is absolute near zero).
  epsilon defaults to 0.000001.

  Works with any Clojure numeric type (longs, doubles, ratios, BigInts,
  BigDecimals); Math/abs has no overload for ratios or big numbers and
  would throw for them."
  ([x y] (float= x y 0.000001))
  ([x y epsilon]
   (let [;; -' auto-promotes, so Long/MIN_VALUE does not overflow
         magnitude (fn [v] (if (neg? v) (-' v) v))
         scale (if (or (zero? x) (zero? y)) 1 (magnitude x))]
     (<= (magnitude (- x y)) (* scale epsilon)))))
(deftest test-k+
  (testing "basics"
    (is (= [12 0] (k+ 6 6)))
    (is (= [12.1 0.0] (k+ [6 0.1] 6)))
    (is (= [12.1 0.0] (k+ 6 [6 0.1])))
    (is (= [12 0] (k+ [6 0] 6))))
  (testing "numerical stability at double precision"
    ;; Sums `iterations` copies of (double 1/iterations), whose exact total is
    ;; ~1. Crank `iterations` up and watch regular-addition-error-term grow,
    ;; while the Kahan error term stays near machine precision.
    (let [iterations 300000
          regular-addition-error-term (+ -1 (reduce + (repeat iterations (double (/ 1 iterations)))))
          kahan-result (reduce k+ (repeat iterations (double (/ 1 iterations))))
          kahan-addition-error-term (+ -1
                                       (first kahan-result)
                                       (second kahan-result))]
      ;; compare magnitudes: a negative error must not pass trivially
      (is (> (Math/abs regular-addition-error-term) 1E-12))
      (is (< (Math/abs kahan-addition-error-term) 1E-15)))))
(defn square
  "Returns x multiplied by itself."
  [x]
  (let [factor x]
    (* factor factor)))
;; Compares the incremental (mergeable) mean reducers against a mean computed
;; directly from a random series, and sets up a reference standard deviation
;; for a small fixed data set.
;; NOTE(review): this form appears truncated/corrupted in this copy of the
;; file: the two `reduce` calls below are missing their collection argument
;; and closing parens, and the "standard deviation" section ends without any
;; assertions or closing parens. It will not compile as-is; restore it from
;; version control before relying on it.
(deftest test-incremental-arithmetic
  (let [iterations 1000
        denominator 10
        ;; random integers in [0, denominator)
        integer-series (repeatedly iterations #(rand-int denominator))
        ;; the same values as doubles: integer / denominator
        number-series (map #(double (/ % denominator)) integer-series)
        ;; [value 1] pairs; presumably the [mean count] inputs for the
        ;; incremental-mean reducers (TODO confirm)
        correctable-series (map #(vector % 1) number-series)
        accuracy 0.0000000000000001]
    (testing "mean"
      ;; reference mean: exact integer/ratio arithmetic, converted to double once
      (let [calculated-mean (double (/ (reduce + integer-series) iterations denominator))
            naive-incremental-mean (double
                                    (reduce naive-incremental-mean ; NOTE(review): collection argument and closing parens missing
            stable-incremental-mean (double
                                     (reduce stable-incremental-mean ; NOTE(review): collection argument and closing parens missing
        ;; naive result: close at the default tolerance but expected NOT to be
        ;; within `accuracy`. NOTE(review): this negative assertion depends on
        ;; the random series producing detectable rounding error.
        (is (float= calculated-mean naive-incremental-mean))
        (is (not (float= calculated-mean naive-incremental-mean accuracy)))
        (is (float= calculated-mean stable-incremental-mean accuracy)))
    ;; TODO: test that incremental averaging on partitions of the
    ;; series yields good results
    (testing "standard deviation"
      (let [data-series [2 4 4 4 5 5 7 9]
            calculated-mean (/ (reduce + data-series) (count data-series))
            ;; x - mean for each element
            differences (map (partial + (- calculated-mean)) data-series)
            square-differences (map square differences)
            ;; population variance: divides by the count, not count - 1
            variance (/ (reduce + square-differences) (count square-differences))
            calculated-std-dev (Math/sqrt variance)]
(ns jupiter.utils.math)
(defprotocol StableArithmetic
  "Addition that implementations may carry out in a numerically stable way
  (for example by tracking a correction term)."
  (stable+
    [this]
    [this y]
    "With one argument, returns it; with two, returns their sum."))
;; A number carried together with a correction term (the rounding error not
;; yet folded into `value`). The best estimate of the exact quantity is
;; value + correction.
(defrecord Correctable [value correction]
  ;; The protocol name is required before its inline method implementations.
  StableArithmetic
  ;; One-argument stable+ is the identity.
  (stable+ [this] this)
  ;; Two-argument stable+ performs one Kahan summation step. y may be
  ;; another Correctable or a plain number (treated as having no
  ;; correction). Returns a new Correctable with the sum and updated
  ;; correction. Records are not Numbers, so plain (+ this y) cannot work.
  (stable+ [this y]
    (let [[y-value y-correction] (if (instance? Correctable y)
                                   [(:value y) (:correction y)]
                                   [y 0])
          ;; fold both pending corrections into the incoming operand
          corrected-y (+ y-value correction y-correction)
          sum (+ value corrected-y)]
      ;; (sum - value) is the part of corrected-y that actually landed in
      ;; sum; the remainder is the rounding error to carry forward.
      (Correctable. sum (- corrected-y (- sum value))))))
;; Plain JVM numbers track no correction: stable+ is the identity for one
;; argument and clojure.core/+ for two.
;; extend-type needs the protocol name, and all arities of one method must
;; be grouped in a single form (separate forms would overwrite each other).
(extend-type java.lang.Number
  StableArithmetic
  (stable+
    ([this] this)
    ([this y] (clojure.core/+ this y))))
(defn correctable
  "Normalizes n to a [value correction] vector.

  Accepts:
    - a number                                      -> [n 0]
    - a sequential whose first two items are numbers -> [(first n) (second n)]
      (any further items are dropped)
    - a map with numeric :value and :correction
      entries (e.g. a Correctable record)            -> [(:value n) (:correction n)]

  Throws an Exception for anything else."
  [n]
  (cond
    (number? n) [n 0]

    (and (sequential? n)
         (number? (first n))
         (number? (second n)))
    [(first n) (second n)]

    (and (map? n)
         (number? (:value n))
         (number? (:correction n)))
    [(:value n) (:correction n)]

    :else (throw (Exception. "Cannot convert that to a correctable number."))))
(defn correct
  "Collapses n (anything `correctable` accepts: a number or a
  [value correction] pair) into a single number, value + correction."
  [n]
  (let [[value correction] (correctable n)]
    (+ value correction)))
(defn k+
  "Kahan increment: numerically stable addition of two numbers that keeps
  track of a running correction term.

  Each argument may be a plain number or a [value correction] 2-tuple
  (anything `correctable` accepts). Returns a [sum correction] vector; the
  best estimate of the exact total is (+ sum correction).

  Because it accepts its own output, k+ can be used directly with reduce:
  (reduce k+ xs)."
  [x y]
  (let [[s1 c1] (correctable x)
        [s2 c2] (correctable y)
        ;; fold both pending corrections into the second operand before
        ;; adding, so previously lost low-order bits are re-injected
        corrected-s2 (+ s2 c1 c2)
        sum (stable+ s1 corrected-s2)
        ;; (sum - s1) is the part of corrected-s2 that actually made it into
        ;; sum; the remainder is the rounding error carried forward
        correction (- corrected-s2 (- sum s1))]
    [sum correction]))
(defn naive-incremental-mean
  "Merges two partial means using plain arithmetic (no correction term).

  Each argument is a [mean count] pair. Returns [merged-mean total-count],
  where merged-mean = mu-a + n-b * (mu-b - mu-a) / (n-a + n-b), i.e. the
  count-weighted mean of mu-a and mu-b."
  [[mu-a n-a] [mu-b n-b]]
  (let [total (+ n-a n-b)
        shift (/ (* n-b (- mu-b mu-a)) total)]
    [(+ mu-a shift) total]))
(defn stable-incremental-mean
  "Merges two partial means, carrying a correction term.

  Same count-weighted update as naive-incremental-mean, but each mean may be
  a plain number or a [value correction] pair, and the update step is added
  with k+. Returns [[value correction] total-count]."
  [[mu-a n-a] [mu-b n-b]]
  (let [left (correctable mu-a)
        right (correctable mu-b)
        total (stable+ n-a n-b)
        delta (- (correct right) (correct left))]
    [(k+ left (/ (* n-b delta) total)) total]))
(defn naive-incremental-std-dev
"Takes tuples of second central moment, mean, n of the subsets of
data to combine."
[[m2-a mu-a n-a] [m2-b mu-b n-b]]
(let [n (+ n-a n-b)
sigma (- mu-b mu-a)
mu (+ mu-a (/ (* n-b sigma) n))
M2-a (* m2-a n-a)
M2-b (* m2-b n-b)])
