;; Gist by @sritchie, last active January 18, 2021 11:47
;; multivariate redo attempt...
;;
;; [[jacobian]] handles the main logic. [[jacobian]] can only take a structural
;; input. [[euclidean]] and [[multivariate]] below widen it to handle,
;; respectively, optionally-structural and multivariable arguments.
(defn- jacobian
  "Takes:

  - some function `f` of a single [[s/structure?]] argument
  - the unperturbed structural `input`
  - a `selectors` vector that is either empty or a valid path into the `input`
    structure

  and returns either:

  - the full [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant)
    of `f` at `input`, if `selectors` is empty, or
  - the entry (or sub-structure) of the Jacobian at `selectors`.

  The Jacobian has the same shape as `input` (or as the sub-structure at
  `selectors`) with all orientations flipped. Multiplying it by an increment in
  the shape of `input` gives the corresponding first-order increment in the
  output of `f`."
  ([f input] (jacobian f input []))
  ([f input selectors]
   (letfn [(prefixed [path]
             (if (empty? selectors)
               path
               (into selectors path)))
           (substitute [path entry]
             (assoc-in input (prefixed path) entry))]
     (if-let [piece (get-in input selectors)]
       (let [frame         (s/transpose piece)
             perturb-entry (fn [entry path]
                             (letfn [(f-entry [x]
                                       (f (substitute path x)))]
                               ;; Each entry takes the derivative of a function
                               ;; of THAT entry; internally, `f-entry`
                               ;; substitutes the perturbed entry into the
                               ;; appropriate place in the full `input` before
                               ;; calling `f`.
                               ((derivative f-entry) entry)))]
         ;; Visit each entry in `frame`, a copy of either the full input or the
         ;; sub-piece living at `selectors` (with all orientations flipped), and
         ;; replace it with the partial derivative of `f` with respect to that
         ;; entry.
         (s/map-chain (fn [entry path _]
                        (if (v/numerical? entry)
                          (perturb-entry entry path)
                          (u/illegal
                           (str "non-numerical entry " entry
                                " in input structure " input))))
                      frame))
       ;; `get-in` returns nil if `selectors` don't index correctly into the
       ;; supplied `input`, triggering this exception.
       (u/illegal (str "Bad selectors " selectors " for structure " input))))))
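
;; A minimal, self-contained sketch of the per-entry strategy that [[jacobian]]
;; uses, in plain Clojure so that it runs without sicmutils. ASSUMPTIONS, for
;; illustration only: nested vectors stand in for `up`/`down` structures,
;; orientations are ignored (so there is no transpose), and a central finite
;; difference stands in for the forward-mode [[derivative]]. `fd-derivative`
;; and `sketch-jacobian` are hypothetical names, not part of the library.
(defn- fd-derivative
  "Hypothetical central-difference approximation to the derivative of the
  scalar function `g`, using a fixed step size `h`."
  [g]
  (let [h 1e-6]
    (fn [x]
      (/ (- (g (+ x h)) (g (- x h)))
         (* 2 h)))))

(defn- sketch-jacobian
  "Hypothetical analogue of [[jacobian]] for nested vectors of numbers. Walks
  every leaf of `input` (or of the piece at `selectors`), substitutes a
  perturbed value at that leaf's full path via `assoc-in`, and differentiates
  `f` with respect to that single entry."
  ([f input] (sketch-jacobian f input []))
  ([f input selectors]
   (letfn [(walk [piece path]
             (if (vector? piece)
               ;; Rebuild the vector, replacing each leaf with its partial.
               (vec (map-indexed (fn [i x] (walk x (conj path i))) piece))
               ;; `path` is the full path into `input`, like `(prefixed path)`.
               (let [f-entry (fn [x] (f (assoc-in input path x)))]
                 ((fd-derivative f-entry) piece))))]
     (if-some [piece (get-in input selectors)]
       (walk piece (vec selectors))
       (throw (ex-info "Bad selectors" {:selectors selectors :input input}))))))

(comment
  ;; f(x, y) = x^2 y at [3 2]: the partials are 2xy = 12 and x^2 = 9, so this
  ;; returns approximately [12.0 9.0].
  (sketch-jacobian (fn [[x y]] (* x x y)) [3 2])
  ;; With selectors [1], only the partial with respect to y: approximately 9.0.
  (sketch-jacobian (fn [[x y]] (* x x y)) [3 2] [1]))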
(defn- euclidean
  "Slightly more general version of [[jacobian]] that can also handle a single
  non-structural input. Returns a function of one argument that dispatches to
  either [[jacobian]] or [[derivative]] depending on that argument's type.

  If you pass non-empty `selectors`, the returned function will throw if it
  receives a non-structural argument."
  ([f] (euclidean f []))
  ([f selectors]
   (let [selectors (vec selectors)]
     (fn [input]
       (cond (s/structure? input)
             (jacobian f input selectors)

             ;; Non-empty selectors are only allowed for functions that receive
             ;; a structural argument. This case passes that single,
             ;; non-structural argument on to `(derivative f)`.
             (empty? selectors)
             ((derivative f) input)

             ;; Any attempt to index (via non-empty selectors) into a
             ;; non-structural argument will throw.
             ;;
             ;; NOTE: What about matrices, maps or sequences? The current
             ;; implementation (as of 0.14.0) pushes the derivative operator
             ;; into the entries, or values, of those types, so they won't reach
             ;; this clause. I (@sritchie) could make a case for allowing the
             ;; first clause here to work for ANY associative structure; then
             ;; you're on your own if you want to call this fn directly.
             :else
             (u/illegal
              (str "Selectors " selectors
                   " not allowed for non-structural input " input)))))))
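
;; A hypothetical, self-contained sketch of the three-way dispatch in
;; [[euclidean]], runnable without sicmutils. ASSUMPTIONS, for illustration
;; only: vectors stand in for sicmutils structures, and the `jacobian-fn` and
;; `derivative-fn` parameters stand in for [[jacobian]] and [[derivative]].
;; `sketch-euclidean` is not a library function.
(defn- sketch-euclidean
  [jacobian-fn derivative-fn f selectors]
  (let [selectors (vec selectors)]
    (fn [input]
      (cond
        ;; Structural input: hand everything, selectors included, to the
        ;; Jacobian.
        (vector? input) (jacobian-fn f input selectors)
        ;; Non-structural input with no selectors: ordinary derivative.
        (empty? selectors) ((derivative-fn f) input)
        ;; Non-structural input with selectors: nothing to index into.
        :else (throw (ex-info "Selectors not allowed for non-structural input"
                              {:selectors selectors :input input}))))))

(comment
  ;; Stand-ins that simply report which branch ran:
  (let [jac   (fn [_ input sels] [:jacobian input sels])
        deriv (fn [_] (fn [x] [:derivative x]))]
    [((sketch-euclidean jac deriv identity [0]) [1 2]) ;; => [:jacobian [1 2] [0]]
     ((sketch-euclidean jac deriv identity []) 3)])    ;; => [:derivative 3]
  ;; ...while calling the `[0]` version on the scalar 3 throws.
  )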
(defn- multivariate
  "Slightly wider version of [[euclidean]]. Accepts:

  - some function `f` of potentially many arguments
  - optionally, a sequence of selectors meant to index into the structural
    argument, or argument vector, of `f`

  and returns a new function that computes either the
  full [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant)
  or the entry at `selectors`.

  Any multivariable function will have its argument vector coerced into an `up`
  structure. Any [[matrix/Matrix]] in a multiple-arg function call will be
  converted into a `down` of `up`s (a row of columns).

  Single-argument functions don't transform their arguments."
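  ([f] (multivariate f []))
  ([f selectors]
   (let [d   #(euclidean % selectors)
         df  (d f)
         ;; Multi-arg calls are packed into a single structural argument by
         ;; `matrix/seq->`, so differentiate a function that spreads that
         ;; argument back into `f`'s positional arguments.
         df* (d #(apply f %))]
     (-> (fn
           ([] (constantly 0))
           ([x] (df x))
           ([x & more]
            (df* (matrix/seq-> (cons x more)))))
         (f/with-arity (f/arity f) {:from ::multivariate})))))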
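
;; A hypothetical, self-contained sketch of how [[multivariate]] routes calls
;; of different arities, runnable without sicmutils. ASSUMPTIONS, for
;; illustration only: `op` stands in for `#(euclidean % selectors)` (it maps a
;; function of one argument to another function of one argument), a plain
;; vector stands in for the `up` built by `matrix/seq->`, and the zero-argument
;; arity is omitted. `sketch-multivariate` is not a library function.
(defn- sketch-multivariate
  [op f]
  (let [df  (op f)
        ;; Multi-arg calls arrive as one packed argument; unpack it for `f`.
        df* (op (fn [args] (apply f args)))]
    (fn
      ;; A single argument is passed through untransformed.
      ([x] (df x))
      ;; Several arguments are packed into one vector before `op` sees them.
      ([x & more] (df* (vec (cons x more)))))))

(comment
  ;; An `op` that reports the argument it receives, alongside `g`'s value:
  (let [spy (fn [g] (fn [arg] [arg (g arg)]))
        h   (sketch-multivariate spy +)]
    [(h 5)       ;; => [5 5]
     (h 1 2 3)])) ;; => [[1 2 3] 6]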