Skip to content

Instantly share code, notes, and snippets.

@KingMob
Created April 15, 2024 12:55
Show Gist options
Save KingMob/24e0c0c3d56efe0b8a285b543556c7a5 to your computer and use it in GitHub Desktop.
Incremental math optimization of a Clojure Gaussian logpdf fn
;; basic version
;; Baseline: compile the next defn with `*unchecked-math*` turned off, so its
;; arithmetic uses Clojure's default (checked) operations.
(set! *unchecked-math* false)
(defn gaussian-logpdf
  "Returns the log probability density of `x` under a Gaussian (normal)
  distribution with location (mean) `mu` and scale (standard deviation)
  `sigma`, both taken from the params map. Both keys must be present and
  `sigma` should be positive.

    log p(x) = -0.5 * (log(2*pi) + 2*log(sigma) + ((x - mu) / sigma)^2)"
  [x {:keys [mu sigma]}]
  (let [;; log of the normalizer 1 / (sigma * sqrt(2*pi)). sigma enters it
        ;; once, i.e. as 2*log(sigma) inside the -0.5 factor.
        z-inv (* -0.5 (+ (* 2.0 (math/log sigma))
                         (math/log 2.0)
                         (math/log math/PI)))
        px    (* -0.5 (math/pow (/ (- x mu)
                                   sigma)
                                2.0))]
    (+ z-inv px)))
;; Turn on unchecked math for the next defn. Nothing later in this file resets
;; it, so every following definition is compiled with it enabled as well.
(set! *unchecked-math* true)
(defn gaussian-logpdf-unchecked
  "Returns the log probability density of `x` under a Gaussian (normal)
  distribution with location (mean) `mu` and scale (standard deviation)
  `sigma`, both taken from the params map. Both keys must be present and
  `sigma` should be positive.

    log p(x) = -0.5 * (log(2*pi) + 2*log(sigma) + ((x - mu) / sigma)^2)

  Same math as `gaussian-logpdf`; in this file it is compiled after
  `*unchecked-math*` has been set to true."
  [x {:keys [mu sigma]}]
  (let [;; log of the normalizer 1 / (sigma * sqrt(2*pi)).
        z-inv (* -0.5 (+ (* 2.0 (math/log sigma))
                         (math/log 2.0)
                         (math/log math/PI)))
        px    (* -0.5 (math/pow (/ (- x mu)
                                   sigma)
                                2.0))]
    (+ z-inv px)))
(defn gaussian-logpdf-hint-or-cast
  "Returns the log probability density of `x` under a Gaussian (normal)
  distribution with location (mean) `mu` and scale (standard deviation)
  `sigma`, both taken from the params map. Both keys must be present and
  `sigma` should be positive.

    log p(x) = -0.5 * (log(2*pi) + 2*log(sigma) + ((x - mu) / sigma)^2)

  Variant: `mu` and `sigma` are coerced once with `double`."
  ;; Hint the returned type here, the compiler won't figure it out,
  ;; causing boxing/unboxing when fns call each other
  ^double
  ;; Hint the primitive argument. Primitive arg hints are supported for fns of
  ;; up to 4 args; a destructured map arg can't take a primitive hint.
  [^double x {:keys [mu sigma]}]
  ;; Do "casts" of mu/sigma... one final boxed operation, but at least the
  ;; compiler knows what they are from then on
  (let [mu    (double mu)
        sigma (double sigma)
        ;; log of the normalizer 1 / (sigma * sqrt(2*pi)).
        z-inv (* -0.5 (+ (* 2.0 (math/log sigma))
                         (math/log 2.0)
                         (math/log math/PI)))
        px    (* -0.5 (math/pow (/ (- x mu)
                                   sigma)
                                2.0))]
    (+ z-inv px)))
(defn gaussian-logpdf-hinted
  "Returns the log probability density of `x` under a Gaussian (normal)
  distribution with location (mean) `mu` and scale (standard deviation)
  `sigma`, both taken from the params map. Both keys must be present and
  `sigma` should be positive.

    log p(x) = -0.5 * (log(2*pi) + 2*log(sigma) + ((x - mu) / sigma)^2)

  Variant: `mu` and `sigma` are hinted ^double at each use site."
  ^double
  [^double x {:keys [mu sigma]}]
  (let [;; log of the normalizer 1 / (sigma * sqrt(2*pi)).
        z-inv (* -0.5 (+ (* 2.0 (math/log ^double sigma)) ;hint
                         (math/log 2.0)
                         (math/log math/PI)))
        px    (* -0.5 (math/pow (/ (- x ^double mu) ;hint
                                   ^double sigma)   ;hint
                                2.0))]
    (+ z-inv px)))
(defn gaussian-logpdf-hint-let
  "Returns the log probability density of `x` under a Gaussian (normal)
  distribution with location (mean) `mu` and scale (standard deviation)
  `sigma`, both taken from the params map. Both keys must be present and
  `sigma` should be positive.

    log p(x) = -0.5 * (log(2*pi) + 2*log(sigma) + ((x - mu) / sigma)^2)

  Variant: `mu` and `sigma` are re-bound as ^double-hinted let locals."
  ^double
  [^double x {:keys [mu sigma]}]
  ;; Apply the hints to let locals to tell the compiler what they are
  (let [^double mu    mu
        ^double sigma sigma
        ;; log of the normalizer 1 / (sigma * sqrt(2*pi)).
        z-inv (* -0.5 (+ (* 2.0 (math/log sigma))
                         (math/log 2.0)
                         (math/log math/PI)))
        px    (* -0.5 (math/pow (/ (- x mu)
                                   sigma)
                                2.0))]
    (+ z-inv px)))
(def ^:const log-of-2
  "Natural log of 2.0, evaluated once when this def is loaded."
  (math/log 2.0))
(def ^:const log-of-pi
  "Natural log of `math/PI`, evaluated once when this def is loaded."
  (math/log math/PI))
(def ^:const log2+logpi
  "Sum of `log-of-2` and `log-of-pi`: the constant log(2*pi) part (up to
  rounding) that the precomputed logpdf variants add to their sigma term."
  (+ log-of-pi log-of-2))
;; Stop recomputing log(2), log(PI), and their sum.
;; Also multiply by -0.5 once instead of twice
(defn gaussian-logpdf-hinted-precomputed
  "Returns the log probability density of `x` under a Gaussian (normal)
  distribution with location (mean) `mu` and scale (standard deviation)
  `sigma`, both taken from the params map. Both keys must be present and
  `sigma` should be positive.

    log p(x) = -0.5 * (log(2*pi) + 2*log(sigma) + ((x - mu) / sigma)^2)

  Variant: uses the precomputed constant `log2+logpi` and applies the -0.5
  factor once."
  ^double
  [^double x {:keys [mu sigma]}]
  (let [;; log(2*pi*sigma^2): the variance contributes 2*log(sigma)
        z-inv (+ (* 2.0 (math/log ^double sigma)) ;hint
                 log2+logpi)
        px    (math/pow (/ (- x ^double mu) ;hint
                           ^double sigma)   ;hint
                        2.0)]
    (* -0.5 (+ z-inv px)))) ; just multiply by -0.5 once here
(defn gaussian-logpdf-hinted-precomputed-squared
  "Returns the log probability density of `x` under a Gaussian (normal)
  distribution with location (mean) `mu` and scale (standard deviation)
  `sigma`, both taken from the params map. Both keys must be present and
  `sigma` should be positive.

    log p(x) = -0.5 * (log(2*pi) + 2*log(sigma) + ((x - mu) / sigma)^2)

  Variant: like `gaussian-logpdf-hinted-precomputed`, but squares by
  multiplication instead of `math/pow`."
  ^double
  [^double x {:keys [mu sigma]}]
  (let [;; log(2*pi*sigma^2): the variance contributes 2*log(sigma)
        z-inv (+ (* 2.0 (math/log ^double sigma)) ;hint
                 log2+logpi)
        ;; Square by multiplying rather than calling math/pow. (The JVM may
        ;; already special-case an exponent of 2.0, so the gain may be small.)
        ;; The quotient is written out twice; binding it once would avoid
        ;; recomputing it.
        px    (* (/ (- x ^double mu) ;hint
                    ^double sigma)
                 (/ (- x ^double mu) ;hint
                    ^double sigma))]
    (* -0.5 (+ z-inv px))))
(defn gaussian-logpdf-hinted-precomputed-get
  "Returns the log probability density of `x` under a Gaussian (normal)
  distribution with location (mean) `mu` and scale (standard deviation)
  `sigma`, read from the map `m` under `:mu` and `:sigma`. Both keys must be
  present and `sigma` should be positive.

    log p(x) = -0.5 * (log(2*pi) + 2*log(sigma) + ((x - mu) / sigma)^2)

  Variant: like `gaussian-logpdf-hinted-precomputed`, but reads the params
  with `get` instead of map destructuring."
  ^double
  [^double x m]
  ;; Use get directly instead of destructuring
  (let [^double mu    (get m :mu)
        ^double sigma (get m :sigma)
        ;; log(2*pi*sigma^2): the variance contributes 2*log(sigma)
        z-inv (+ (* 2.0 (math/log ^double sigma)) ;hint
                 log2+logpi)
        px    (math/pow (/ (- x ^double mu) ;hint
                           ^double sigma)   ;hint
                        2.0)]
    (* -0.5 (+ z-inv px))))
(comment
  (require '[criterium.core :as crit])

  (let [n 15
        x 1.27345
        gaussian-params {:mu 0.1 :sigma 1.1}
        ;; All variants should agree. They sum their floating-point terms in
        ;; different orders (e.g. the precomputed variants add log(2)+log(pi)
        ;; first), so compare within a tolerance instead of with exact `=`.
        results [(gaussian-logpdf x gaussian-params)
                 (gaussian-logpdf-unchecked x gaussian-params)
                 (gaussian-logpdf-hint-or-cast x gaussian-params)
                 (gaussian-logpdf-hinted x gaussian-params)
                 (gaussian-logpdf-hint-let x gaussian-params)
                 (gaussian-logpdf-hinted-precomputed x gaussian-params)
                 (gaussian-logpdf-hinted-precomputed-squared x gaussian-params)
                 (gaussian-logpdf-hinted-precomputed-get x gaussian-params)]
        reference (double (first results))]
    (assert (every? #(< (Math/abs (- (double %) reference)) 1e-12) results)
            (str "gaussian-logpdf variants disagree: " results))

    ;; Each benchmarked expression calls the fn `n` times, so criterium's
    ;; reported times cover `n` calls, not one. The calls are written out
    ;; explicitly (instead of looping over a seq of fns) so every call site
    ;; invokes the var directly and can use the hinted primitive signatures.
    (println "\n>>>> gaussian-logpdf >>>>")
    (crit/quick-bench
      (dotimes [_ n]
        (gaussian-logpdf x gaussian-params)))

    (println "\n>>>> gaussian-logpdf-unchecked >>>>")
    (crit/quick-bench
      (dotimes [_ n]
        (gaussian-logpdf-unchecked x gaussian-params)))

    (println "\n>>>> gaussian-logpdf-hint-or-cast >>>>")
    (crit/quick-bench
      (dotimes [_ n]
        (gaussian-logpdf-hint-or-cast x gaussian-params)))

    (println "\n>>>> gaussian-logpdf-hinted >>>>")
    (crit/quick-bench
      (dotimes [_ n]
        (gaussian-logpdf-hinted x gaussian-params)))

    (println "\n>>>> gaussian-logpdf-hint-let >>>>")
    (crit/quick-bench
      (dotimes [_ n]
        (gaussian-logpdf-hint-let x gaussian-params)))

    (println "\n>>>> gaussian-logpdf-hinted-precomputed >>>>")
    (crit/quick-bench
      (dotimes [_ n]
        (gaussian-logpdf-hinted-precomputed x gaussian-params)))

    (println "\n>>>> gaussian-logpdf-hinted-precomputed-squared >>>>")
    (crit/quick-bench
      (dotimes [_ n]
        (gaussian-logpdf-hinted-precomputed-squared x gaussian-params)))

    (println "\n>>>> gaussian-logpdf-hinted-precomputed-get >>>>")
    (crit/quick-bench
      (dotimes [_ n]
        (gaussian-logpdf-hinted-precomputed-get x gaussian-params))))
  )
;; Recorded criterium `quick-bench` output from running the benchmark
;; `comment` form above. Each benchmarked expression calls the fn `n` = 15
;; times, so every "Execution time" figure below covers 15 calls, not one.
(def output "
>>>> gaussian-logpdf >>>>
Evaluation count : 638088 in 6 samples of 106348 calls.
Execution time mean : 1.028064 µs
Execution time std-deviation : 85.231381 ns
Execution time lower quantile : 922.171362 ns ( 2.5%)
Execution time upper quantile : 1.134582 µs (97.5%)
Overhead used : 6.514941 ns
>>>> gaussian-logpdf-unchecked >>>>
Evaluation count : 682296 in 6 samples of 113716 calls.
Execution time mean : 929.249972 ns
Execution time std-deviation : 37.452359 ns
Execution time lower quantile : 886.171568 ns ( 2.5%)
Execution time upper quantile : 977.026852 ns (97.5%)
Overhead used : 6.514941 ns
>>>> gaussian-logpdf-hint-or-cast >>>>
Evaluation count : 731484 in 6 samples of 121914 calls.
Execution time mean : 862.655163 ns
Execution time std-deviation : 59.972538 ns
Execution time lower quantile : 812.251882 ns ( 2.5%)
Execution time upper quantile : 954.780409 ns (97.5%)
Overhead used : 6.514941 ns
>>>> gaussian-logpdf-hinted >>>>
Evaluation count : 754440 in 6 samples of 125740 calls.
Execution time mean : 814.168622 ns
Execution time std-deviation : 23.058194 ns
Execution time lower quantile : 793.655305 ns ( 2.5%)
Execution time upper quantile : 840.419890 ns (97.5%)
Overhead used : 6.514941 ns
>>>> gaussian-logpdf-hint-let >>>>
Evaluation count : 759960 in 6 samples of 126660 calls.
Execution time mean : 818.227551 ns
Execution time std-deviation : 18.170133 ns
Execution time lower quantile : 799.199021 ns ( 2.5%)
Execution time upper quantile : 842.498653 ns (97.5%)
Overhead used : 6.514941 ns
>>>> gaussian-logpdf-hinted-precomputed >>>>
Evaluation count : 940902 in 6 samples of 156817 calls.
Execution time mean : 653.033033 ns
Execution time std-deviation : 37.353899 ns
Execution time lower quantile : 618.358788 ns ( 2.5%)
Execution time upper quantile : 701.936384 ns (97.5%)
Overhead used : 6.514941 ns
>>>> gaussian-logpdf-hinted-precomputed-squared >>>>
Evaluation count : 949932 in 6 samples of 158322 calls.
Execution time mean : 657.782670 ns
Execution time std-deviation : 33.428286 ns
Execution time lower quantile : 622.689841 ns ( 2.5%)
Execution time upper quantile : 704.922257 ns (97.5%)
Overhead used : 6.514941 ns
>>>> gaussian-logpdf-hinted-precomputed-get >>>>
Evaluation count : 2810646 in 6 samples of 468441 calls.
Execution time mean : 219.117072 ns
Execution time std-deviation : 9.265996 ns
Execution time lower quantile : 208.328944 ns ( 2.5%)
Execution time upper quantile : 229.442202 ns (97.5%)
Overhead used : 6.514941 ns
=> nil
")
;; Author's notes interpreting the timings recorded in `output`. Overall, the
;; mean time per 15-call sample went from about 1.03 µs (basic version) to
;; about 219 ns (`get` version).
(def comments "
By the end, we've shaved off 80% of the run time. (Probably more,
because there's a certain amount of overhead here in benching.)
We started with math improvements, but slowdowns are frequently
not where you expect. The biggest improvement comes from replacing
destructuring with the use of plain `get`.
The 2nd biggest win comes from not recomputing log(2.0), log(pi), and
their sum. In theory, a sufficiently smart compiler can detect that
these are constants, and lift them out for you. The Clojure compiler is
not that smart.
To see what's going on at the Java level, make sure the `classes` dir
exists, use the `compile` fn, like `(compile (ns-name *ns*))`, and open
up the classfiles.
In hinted vs hint-or-cast, it's hard to predict which is better a priori,
and they're pretty close. hint-or-cast uses casts early on to ensure the use of
primitive math afterwards, but those casts are boxed, because the compiler
doesn't know the inputs. hinted saves on the cost of those initial casts,
while requiring unchecked casts later. I would expect that, the more math
done, the better hint-or-cast will do.
")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment