;; forward autodiff experiments
;; GitHub gist whilo/fa88a15d69d4be7376a89d545dcbb1bc by @whilo, created July 13, 2018.
(ns autodiff.forward
"Experiments for forward differentiation.")
;; Alternative idea: encode the tabular SSA structure of Griewank, chapter 3.1.
;; Questions:
;; - Which SSA representation is good for implementing compiler optimizations?
;; TODO
;; - avoid vector-caused boxing on boundaries between SSA blocks
;; - reverse mode
;; - symbol hygiene
;; - Stalingrad-style self-hosting (?)
;; - nested values/data structures?
;; - anonymous functions
;; - other operators: -, /, exp, log
(defn term?
  "True when `exp` is atomic for differentiation, i.e. a number or a symbol;
  false otherwise."
  [exp]
  ((some-fn number? symbol?) exp))
(defn dispatch-exp
  "Dispatch function for `diff-exp`: returns the operator of the flat
  arithmetic expression `exp` as a keyword, e.g. `'(* x y)` -> `:*`.

  Throws an AssertionError (when assertions are enabled) if `exp` is not a
  seq, or if any operand is itself a seq, i.e. the expression has not been
  flattened into SSA form yet."
  [exp]
  (assert (seq? exp) "All differentiation happens on arithmetic expressions.")
  ;; Every operand must be atomic; nested forms have to be flattened first.
  (assert (not-any? seq? (rest exp))
          "Differentiation works on flat (not-nested) expressions only.")
  (keyword (first exp)))
(defn **
  "Raise `base` to the power `expo` via `Math/pow` (the result is a double).
  Generated derivative code for `**` expressions calls this function."
  [base expo]
  ;; probably better to support Math/pow directly (?)
  (Math/pow base expo))
(defmulti diff-exp
  "Return the symbolic tangent of a flat arithmetic expression, dispatching
  on its operator keyword via `dispatch-exp`. The tangent of a symbol `x` is
  written as the symbol `x'`."
  dispatch-exp)
(defmethod diff-exp :+ [[_ & args]]
  ;; Sum rule: d(a + b + ...) = a' + b' + ...
  ;; Numeric literals have zero tangent and are dropped; with no symbolic
  ;; operands the result is `(+)`, which evaluates to 0.
  (cons '+ (for [operand args
                 :when (symbol? operand)]
             (symbol (str operand "'")))))
(comment
  (diff-exp '(+ 1 2 x)))
(defmethod diff-exp :* [[_ & args]]
  ;; Product rule: one summand per symbolic factor, in which that factor is
  ;; replaced by its tangent. Numeric factors have zero tangent and yield no
  ;; summand; with none left the result is `(+)`, i.e. 0.
  (let [factors (vec args)
        prime   #(symbol (str % "'"))]
    (cons '+ (keep-indexed (fn [i factor]
                             (when (symbol? factor)
                               (seq (into '[*] (update factors i prime)))))
                           factors))))
(comment
  (diff-exp '(* x y)))
(defmethod diff-exp :** [[_ & args]]
  ;; General power rule for b^e:
  ;;   d(b^e) = e * b^(e-1) * b'  +  b^e * ln(b) * e'
  ;; A numeric operand has zero tangent, so its summand is omitted entirely.
  ;; The ln(b) summand is guarded by `(zero? e')`: for b <= 0, ln(b) is NaN
  ;; or -Infinity, and multiplying that by a zero tangent would yield NaN
  ;; even though the exponent does not actually vary.
  (let [[base expo] args
        prime     #(symbol (str % "'"))
        base-term (when (symbol? base)
                    (list '* expo (list '** base (list 'dec expo)) (prime base)))
        expo-term (when (symbol? expo)
                    (list 'if (list 'zero? (prime expo))
                          0
                          (list '* (list '** base expo)
                                (list 'Math/log base)
                                (prime expo))))]
    ;; With two numeric operands both summands are absent: `(+)` evaluates to 0.
    (cons '+ (remove nil? [base-term expo-term]))))
(comment
  (diff-exp '(** x 3))
  (diff-exp '(** x y)))
;;;;;;;;;;;;;;;; SSA <-> clj ;;;;;;;;;;;;;;;;
(defn last-expr-or-recur
  "Pick the form to emit as the body for an SSA trace: `exp` itself when it
  is a `(recur ...)` form, otherwise the value part of the last entry in
  `ssa`."
  [exp ssa]
  (let [recur-form? (and (seq? exp)
                         (= 'recur (first exp)))]
    (if recur-form?
      exp
      (-> ssa last second))))
(defn forward-ssa->clj
  "Lift an SSA trace back into executable Clojure code.

  `ssa` is a sequence of `[[s s'] value]` entries. Every entry except the
  last becomes a `let` binding. The last entry's value becomes the body
  (chosen via `last-expr-or-recur`), so it is evaluated exactly once and
  stays in tail position, as `recur` requires. A one-entry trace yields the
  body without a `let`; an empty trace yields nil."
  [ssa]
  (let [bind-entry (fn [[[s s'] v]]
                     (if (vector? v)
                       ;; `[primal tangent]` forms: bind both halves directly
                       ;; so no intermediate vector is built (avoids boxing).
                       [s (first v) s' (second v)]
                       ;; A single form (`if`/`loop`) that evaluates to a
                       ;; `[value tangent]` pair: destructure it.
                       [[s s'] v]))
        bindings   (vec (mapcat bind-entry (butlast ssa)))
        [exp]      (second (last ssa))
        body       (last-expr-or-recur exp ssa)]
    (if (empty? bindings)
      body
      (list 'let bindings body))))
(defn dispatch-forward
  "Dispatch function for `clj->forward-ssa`. Only `exp` is inspected: a seq
  dispatches on its operator as a keyword (`:+`, `:let`, `:if`, ...), a
  number or symbol dispatches to `:term`, and anything else yields nil."
  [sym exp ssa]
  (if (seq? exp)
    (keyword (first exp))
    (when (term? exp)
      :term)))
(defmulti clj->forward-ssa
  "Flatten expression `exp` into the SSA trace `ssa` and return the extended
  trace. The trace is a vector of `[[s s'] value]` entries, where `value` is
  either a `[primal tangent]` pair of forms or a single form (`if`/`loop`)
  that evaluates to such a pair. Sub-expressions are flattened depth first,
  so their entries are appended (via `conj`) before the entry for `exp`
  itself, which is normally bound to `sym`. Dispatches on `dispatch-forward`."
  dispatch-forward)
(defn ssa-expr
  "Flatten the operator application `exp`, e.g. `(+ a (* b c))`, into the
  SSA trace `ssa` and return the extended trace.

  Term operands are kept as they are. Each compound operand is first
  flattened into the trace under a fresh gensym. Then an entry binding `sym`
  and `sym'` to the flat expression and its `diff-exp` tangent is appended."
  [sym exp ssa]
  (let [[op & operands] exp
        names     (map #(if (term? %) % (gensym "v")) operands)
        flat-exp  (cons op names)
        trace     (reduce (fn [acc [nm operand]]
                            (if-not (term? operand)
                              (clj->forward-ssa nm operand acc)
                              acc))
                          ssa
                          (map vector names operands))
        sym-prime (symbol (str sym "'"))]
    ;; forward diff happens here
    (conj trace [[sym sym-prime] [flat-exp (diff-exp flat-exp)]])))
(defmethod clj->forward-ssa :term
  [sym exp ssa]
  ;; A numeric literal has tangent 0; a symbol `x` carries its tangent in `x'`.
  ;; TODO check for differentiated symbols
  (let [sym-prime (symbol (str sym "'"))
        tangent   (if (symbol? exp)
                    (symbol (str exp "'"))
                    0)]
    (conj ssa [[sym sym-prime] [exp tangent]])))
;; `+`, `*` and `**` are plain operator applications: all three are
;; flattened by `ssa-expr`, which also emits their tangents via `diff-exp`.
(doseq [op [:+ :* :**]]
  (defmethod clj->forward-ssa op
    [sym exp ssa]
    (ssa-expr sym exp ssa)))
(comment
  (clj->forward-ssa (gensym "v") '(+ 5 (* (+ x 7) x)) [])
  (clj->forward-ssa (gensym "v") '(** x x) []))
(defmethod clj->forward-ssa :let
  [sym exp ssa]
  ;; TODO support destructuring syntax
  ;; Bindings are flattened in order under their own names, so `k'` holds
  ;; the tangent of binding `k`. The body's value is bound to `sym`: that is
  ;; the name an enclosing expression (e.g. an operand flattened by
  ;; `ssa-expr`) uses to refer to this `let`.
  ;; NOTE: only a single body form is supported; further body forms are ignored.
  (let [[_ bindings body] exp
        ssa-bindings (vec (reduce (fn [acc [k v]]
                                    (clj->forward-ssa k v acc))
                                  ssa
                                  (partition 2 bindings)))]
    (clj->forward-ssa sym body ssa-bindings)))
(comment
  (clojure.pprint/pprint
   (clj->forward-ssa (gensym "v")
                     '(let [y (+ 5 (* (+ x 7) x))]
                        (* (+ z 5) (* x y)))
                     [])))
;; following Griewank p. 125 f
(defmethod clj->forward-ssa :if
  [sym exp ssa]
  ;; The condition is kept as primal code. Each branch is differentiated in
  ;; its own fresh trace and lifted back to Clojure, so the resulting `if`
  ;; form is stored as the entry's single value.
  (let [[_ test-form then-form else-form] exp
        lift-branch (fn [branch]
                      (forward-ssa->clj (clj->forward-ssa (gensym "v") branch [])))]
    (conj ssa
          [[sym (symbol (str sym "'"))]
           (list 'if test-form (lift-branch then-form) (lift-branch else-form))])))
(comment
  (clj->forward-ssa (gensym "v") '(if true (+ 42 x) (* 32 y)) []))
(defmethod clj->forward-ssa :recur
  [sym exp ssa]
  ;; Each recur argument is lifted independently into code yielding its
  ;; `[value tangent]` pair. The entry's names are throwaway gensyms, and its
  ;; value is the `recur` form paired with a constant 0.
  (let [names    [(gensym "v") (gensym "v")]
        lift-arg (fn [arg]
                   (forward-ssa->clj (clj->forward-ssa (gensym "v") arg [])))]
    (conj ssa [names [(cons 'recur (map lift-arg (rest exp))) 0]])))
(comment
  (clj->forward-ssa (gensym "v") '(recur (+ x 1)) []))
(defmethod clj->forward-ssa :loop
  [sym exp ssa]
  ;; Each loop variable `k` is bound with the destructuring pattern `[k k']`
  ;; to code yielding its `[value tangent]` pair, so every `recur` has to
  ;; supply one such pair per variable. Each init form is lifted on its own.
  ;; Flattening all inits into one shared trace would turn their intermediate
  ;; SSA names into extra loop variables that `recur` does not supply.
  ;; The loop's result is bound to `sym`/`sym'`, so enclosing expressions and
  ;; `let` bindings can refer to it.
  (let [[_ bindings body] exp
        lift          (fn [form]
                        (forward-ssa->clj (clj->forward-ssa (gensym "v") form [])))
        loop-bindings (vec (mapcat (fn [[k init]]
                                     [[k (symbol (str k "'"))] (lift init)])
                                   (partition 2 bindings)))]
    (conj ssa
          [[sym (symbol (str sym "'"))]
           (list 'loop loop-bindings
                 (forward-ssa->clj
                  (clj->forward-ssa (gensym) body [])))])))
(comment
  (clojure.pprint/pprint
   (clj->forward-ssa (gensym "v") '(loop [x 0]
                                     (if (< x 10)
                                       (recur (+ x 1))
                                       x))
                     []))
  (forward-ssa->clj (clj->forward-ssa (gensym "v") '(loop [x 0]
                                                       (if (< x 10)
                                                         (recur (+ (* x) 1))
                                                         x))
                                      [])))
(defn forward-diff
  "Forward-mode differentiate the Clojure expression `exp`.

  Returns code (not a value). Where that code is evaluated, every free
  symbol `x` of `exp` must be bound together with its tangent `x'`; the
  code then normally evaluates to a `[value tangent]` pair (see the
  examples below)."
  [exp]
  (-> (gensym "v")
      (clj->forward-ssa exp [])
      forward-ssa->clj))
(comment
  (eval
   (list 'let ['x 5
               'y 3
               'x' 1
               'y' 0]
         (forward-diff '(let [c (+ 4 x)]
                          (if (> c 0)
                            (+ (* x c) y)
                            5)))))
  (==
   (second
    (eval
     (list 'let ['x 5
                 'y 3
                 'x' 1
                 'y' 0]
           (forward-diff '(let [c x]
                            (loop [i 0
                                   d c]
                              (if (= i 5)
                                d
                                (recur (+ i 1) (* x d)))))))))
   (* 6 (Math/pow 5 5))))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment