Created
April 15, 2024 12:55
-
-
Save KingMob/24e0c0c3d56efe0b8a285b543556c7a5 to your computer and use it in GitHub Desktop.
Incremental math optimization of a Clojure Gaussian logpdf fn
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
;; NOTE(review): `math` is presumably an alias for clojure.math (Clojure 1.11+);
;; the `require` that sets it up is not shown in this listing.
;;
;; Every variant below computes the log-density of a normal distribution with
;; mean `mu` and standard deviation `sigma`:
;;
;;   log p(x) = -0.5 * (2*log(sigma) + log(2) + log(pi) + ((x - mu) / sigma)^2)
;;
;; The 2*log(sigma) term matters: the quadratic term divides by `sigma` before
;; squaring, so `sigma` is a standard deviation and the normalizer needs
;; log(sigma^2). With a plain log(sigma) the density only integrates to 1 when
;; sigma = 1.
;;
;; log(2) and log(pi) are always added to each other first, so every variant
;; sums its terms in the same order (the benchmark form compares the variants'
;; results with `=`).

;; basic version
(set! *unchecked-math* false)

(defn gaussian-logpdf
  "Returns the log probability density of `x` under a Gaussian distribution
  with mean `mu` and standard deviation `sigma`. Both keys of the parameter
  map are required.

  Baseline variant: checked math, no type hints."
  [x {:keys [mu sigma]}]
  (let [z-inv (* -0.5 (+ (* 2.0 (math/log sigma)) ; log(sigma^2)
                         (+ (math/log 2.0)
                            (math/log math/PI))))
        px    (* -0.5 (math/pow (/ (- x mu)
                                   sigma)
                                2.0))]
    (+ z-inv px)))

(set! *unchecked-math* true)

(defn gaussian-logpdf-unchecked
  "Returns the log probability density of `x` under a Gaussian distribution
  with mean `mu` and standard deviation `sigma`. Both keys of the parameter
  map are required.

  Same code as the baseline, but compiled with `*unchecked-math*` enabled."
  [x {:keys [mu sigma]}]
  (let [z-inv (* -0.5 (+ (* 2.0 (math/log sigma)) ; log(sigma^2)
                         (+ (math/log 2.0)
                            (math/log math/PI))))
        px    (* -0.5 (math/pow (/ (- x mu)
                                   sigma)
                                2.0))]
    (+ z-inv px)))

(defn gaussian-logpdf-hint-or-cast
  "Returns the log probability density of `x` under a Gaussian distribution
  with mean `mu` and standard deviation `sigma`. Both keys of the parameter
  map are required.

  Hints the return value and `x`, and casts `mu`/`sigma` to double up front."
  ;; Hint the returned type here, the compiler won't figure it out,
  ;; causing boxing/unboxing when fns call each other
  ^double
  ;; Hint the primitive argument. Works with fns up to 4 args, iirc
  ;; Doesn't work with destructuring, unfortunately
  [^double x {:keys [mu sigma]}]
  ;; Do "casts" of mu/sigma... one final boxed operation, but at least the
  ;; compiler knows what they are from then on
  (let [mu    (double mu)
        sigma (double sigma)
        z-inv (* -0.5 (+ (* 2.0 (math/log sigma)) ; log(sigma^2)
                         (+ (math/log 2.0)
                            (math/log math/PI))))
        px    (* -0.5 (math/pow (/ (- x mu)
                                   sigma)
                                2.0))]
    (+ z-inv px)))

(defn gaussian-logpdf-hinted
  "Returns the log probability density of `x` under a Gaussian distribution
  with mean `mu` and standard deviation `sigma`. Both keys of the parameter
  map are required.

  Hints `mu` and `sigma` at each point of use instead of casting them."
  ^double
  [^double x {:keys [mu sigma]}]
  (let [z-inv (* -0.5 (+ (* 2.0 (math/log ^double sigma)) ;hint; log(sigma^2)
                         (+ (math/log 2.0)
                            (math/log math/PI))))
        px    (* -0.5 (math/pow (/ (- x ^double mu) ;hint
                                   ^double sigma)   ;hint
                                2.0))]
    (+ z-inv px)))

(defn gaussian-logpdf-hint-let
  "Returns the log probability density of `x` under a Gaussian distribution
  with mean `mu` and standard deviation `sigma`. Both keys of the parameter
  map are required.

  Re-binds `mu` and `sigma` as hinted `let` locals."
  ^double
  [^double x {:keys [mu sigma]}]
  ;; Apply the hints to let locals to tell the compiler what they are
  (let [^double mu    mu
        ^double sigma sigma
        z-inv (* -0.5 (+ (* 2.0 (math/log sigma)) ; log(sigma^2)
                         (+ (math/log 2.0)
                            (math/log math/PI))))
        px    (* -0.5 (math/pow (/ (- x mu)
                                   sigma)
                                2.0))]
    (+ z-inv px)))

(def ^:const log-of-2 (math/log 2.0))
(def ^:const log-of-pi (math/log math/PI))
(def ^:const log2+logpi (+ log-of-2 log-of-pi))

;; Stop recomputing log(2), log(PI), and their sum.
;; Also multiply by -0.5 once instead of twice
(defn gaussian-logpdf-hinted-precomputed
  "Returns the log probability density of `x` under a Gaussian distribution
  with mean `mu` and standard deviation `sigma`. Both keys of the parameter
  map are required.

  Uses the precomputed constant `log2+logpi` and a single multiplication
  by -0.5."
  ^double
  [^double x {:keys [mu sigma]}]
  (let [z-inv (+ (* 2.0 (math/log ^double sigma)) ;hint; log(sigma^2)
                 log2+logpi)
        px    (math/pow (/ (- x ^double mu) ;hint
                           ^double sigma)   ;hint
                        2.0)]
    (* -0.5 (+ z-inv px)))) ; just multiply by -0.5 once here

(defn gaussian-logpdf-hinted-precomputed-squared
  "Returns the log probability density of `x` under a Gaussian distribution
  with mean `mu` and standard deviation `sigma`. Both keys of the parameter
  map are required.

  Like the precomputed variant, but squares by explicit multiplication
  instead of calling `math/pow`."
  ^double
  [^double x {:keys [mu sigma]}]
  (let [z-inv (+ (* 2.0 (math/log ^double sigma)) ;hint; log(sigma^2)
                 log2+logpi)
        ;; math/pow special-cases power-of-2 as x*x
        ;; depending on complexity, may be worth computing
        ;; the standardized value only once
        px    (* (/ (- x ^double mu) ;hint
                    ^double sigma)
                 (/ (- x ^double mu) ;hint
                    ^double sigma))]
    (* -0.5 (+ z-inv px))))

(defn gaussian-logpdf-hinted-precomputed-get
  "Returns the log probability density of `x` under a Gaussian distribution.
  `m` is a map holding the mean under `:mu` and the standard deviation under
  `:sigma`; both keys are required.

  Like the precomputed variant, but reads the parameters with `get` rather
  than destructuring."
  ^double
  [^double x m]
  ;; Use get directly instead of destructuring
  (let [^double mu    (get m :mu)
        ^double sigma (get m :sigma)
        z-inv (+ (* 2.0 (math/log ^double sigma)) ;hint; log(sigma^2)
                 log2+logpi)
        px    (math/pow (/ (- x ^double mu) ;hint
                           ^double sigma)   ;hint
                        2.0)]
    (* -0.5 (+ z-inv px))))
(comment
  ;; REPL-only scratch form: first check that every variant returns the same
  ;; value for one sample input, then benchmark each variant with criterium.
  (require '[criterium.core :as crit])

  (let [reps   15
        sample 1.27345
        params {:mu 0.1 :sigma 1.1}]
    ;; Sanity check: all implementations must agree exactly
    (assert (= (gaussian-logpdf sample params)
               (gaussian-logpdf-unchecked sample params)
               (gaussian-logpdf-hint-or-cast sample params)
               (gaussian-logpdf-hinted sample params)
               (gaussian-logpdf-hint-let sample params)
               (gaussian-logpdf-hinted-precomputed sample params)
               (gaussian-logpdf-hinted-precomputed-squared sample params)
               (gaussian-logpdf-hinted-precomputed-get sample params)))

    ;; Each benchmark below times a batch of `reps` direct calls
    (println "\n>>>> gaussian-logpdf >>>>")
    (crit/quick-bench
      (dotimes [_ reps]
        (gaussian-logpdf sample params)))

    (println "\n>>>> gaussian-logpdf-unchecked >>>>")
    (crit/quick-bench
      (dotimes [_ reps]
        (gaussian-logpdf-unchecked sample params)))

    (println "\n>>>> gaussian-logpdf-hint-or-cast >>>>")
    (crit/quick-bench
      (dotimes [_ reps]
        (gaussian-logpdf-hint-or-cast sample params)))

    (println "\n>>>> gaussian-logpdf-hinted >>>>")
    (crit/quick-bench
      (dotimes [_ reps]
        (gaussian-logpdf-hinted sample params)))

    (println "\n>>>> gaussian-logpdf-hint-let >>>>")
    (crit/quick-bench
      (dotimes [_ reps]
        (gaussian-logpdf-hint-let sample params)))

    (println "\n>>>> gaussian-logpdf-hinted-precomputed >>>>")
    (crit/quick-bench
      (dotimes [_ reps]
        (gaussian-logpdf-hinted-precomputed sample params)))

    (println "\n>>>> gaussian-logpdf-hinted-precomputed-squared >>>>")
    (crit/quick-bench
      (dotimes [_ reps]
        (gaussian-logpdf-hinted-precomputed-squared sample params)))

    (println "\n>>>> gaussian-logpdf-hinted-precomputed-get >>>>")
    (crit/quick-bench
      (dotimes [_ reps]
        (gaussian-logpdf-hinted-precomputed-get sample params)))))
;; Benchmark report kept as a string literal for reference. Its section
;; banners match the `println` calls in the `comment` form above, so this is
;; presumably the captured output of evaluating that form.
;; NOTE(review): each benchmarked expression is a `dotimes` batch of 15 calls,
;; so the reported times are presumably per batch, not per call — confirm.
(def output " | |
>>>> gaussian-logpdf >>>> | |
Evaluation count : 638088 in 6 samples of 106348 calls. | |
Execution time mean : 1.028064 µs | |
Execution time std-deviation : 85.231381 ns | |
Execution time lower quantile : 922.171362 ns ( 2.5%) | |
Execution time upper quantile : 1.134582 µs (97.5%) | |
Overhead used : 6.514941 ns | |
>>>> gaussian-logpdf-unchecked >>>> | |
Evaluation count : 682296 in 6 samples of 113716 calls. | |
Execution time mean : 929.249972 ns | |
Execution time std-deviation : 37.452359 ns | |
Execution time lower quantile : 886.171568 ns ( 2.5%) | |
Execution time upper quantile : 977.026852 ns (97.5%) | |
Overhead used : 6.514941 ns | |
>>>> gaussian-logpdf-hint-or-cast >>>> | |
Evaluation count : 731484 in 6 samples of 121914 calls. | |
Execution time mean : 862.655163 ns | |
Execution time std-deviation : 59.972538 ns | |
Execution time lower quantile : 812.251882 ns ( 2.5%) | |
Execution time upper quantile : 954.780409 ns (97.5%) | |
Overhead used : 6.514941 ns | |
>>>> gaussian-logpdf-hinted >>>> | |
Evaluation count : 754440 in 6 samples of 125740 calls. | |
Execution time mean : 814.168622 ns | |
Execution time std-deviation : 23.058194 ns | |
Execution time lower quantile : 793.655305 ns ( 2.5%) | |
Execution time upper quantile : 840.419890 ns (97.5%) | |
Overhead used : 6.514941 ns | |
>>>> gaussian-logpdf-hint-let >>>> | |
Evaluation count : 759960 in 6 samples of 126660 calls. | |
Execution time mean : 818.227551 ns | |
Execution time std-deviation : 18.170133 ns | |
Execution time lower quantile : 799.199021 ns ( 2.5%) | |
Execution time upper quantile : 842.498653 ns (97.5%) | |
Overhead used : 6.514941 ns | |
>>>> gaussian-logpdf-hinted-precomputed >>>> | |
Evaluation count : 940902 in 6 samples of 156817 calls. | |
Execution time mean : 653.033033 ns | |
Execution time std-deviation : 37.353899 ns | |
Execution time lower quantile : 618.358788 ns ( 2.5%) | |
Execution time upper quantile : 701.936384 ns (97.5%) | |
Overhead used : 6.514941 ns | |
>>>> gaussian-logpdf-hinted-precomputed-squared >>>> | |
Evaluation count : 949932 in 6 samples of 158322 calls. | |
Execution time mean : 657.782670 ns | |
Execution time std-deviation : 33.428286 ns | |
Execution time lower quantile : 622.689841 ns ( 2.5%) | |
Execution time upper quantile : 704.922257 ns (97.5%) | |
Overhead used : 6.514941 ns | |
>>>> gaussian-logpdf-hinted-precomputed-get >>>> | |
Evaluation count : 2810646 in 6 samples of 468441 calls. | |
Execution time mean : 219.117072 ns | |
Execution time std-deviation : 9.265996 ns | |
Execution time lower quantile : 208.328944 ns ( 2.5%) | |
Execution time upper quantile : 229.442202 ns (97.5%) | |
Overhead used : 6.514941 ns | |
=> nil | |
") | |
(def comments "
By the end, we've shaved off roughly 80% of the run time (mean time per
benchmarked batch drops from about 1.03 µs to about 0.22 µs). Probably more,
because there's a certain amount of overhead here in benching.

We started with math improvements, but slowdowns are frequently
not where you expect. The biggest improvement comes from replacing
destructuring with the use of plain `get`.

The 2nd biggest win comes from not recomputing log(2.0), log(pi), and
their sum. In theory, a sufficiently smart compiler can detect that
these are constants, and lift them out for you. The Clojure compiler is
not that smart.

To see what's going on at the Java level, make sure the `classes` dir
exists, use the `compile` fn, like `(compile (ns-name *ns*))`, and open
up the classfiles.

In hinted vs hint-or-cast, it's hard to predict which is better a priori,
and they're pretty close. hint-or-cast uses casts early on to ensure the use of
primitive math afterwards, but those casts are boxed, because the compiler
doesn't know the inputs. hinted saves on the cost of those initial casts,
while requiring unchecked casts later. I would expect that the more math is
done, the better hint-or-cast will do.
")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment