Skip to content

Instantly share code, notes, and snippets.

@mthomure
Last active January 16, 2017 20:25
Show Gist options
  • Save mthomure/45e6606c5a625ba0baa3d92db0a821e0 to your computer and use it in GitHub Desktop.
Save mthomure/45e6606c5a625ba0baa3d92db0a821e0 to your computer and use it in GitHub Desktop.
First steps toward an idiomatic Clojure wrapper for MALLET
(ns learn-mallet.core
(:import [cc.mallet.optimize Optimizable$ByGradientValue
ConjugateGradient GradientAscent LimitedMemoryBFGS
OptimizerEvaluator$ByGradient]))
;; Requires the project dependency [cc.mallet/mallet "2.0.8"].
(defprotocol IProblem
  "Protocol for inspecting an optimization problem."
  (problem-state [this]
    "Returns a map describing the problem's current state."))
(extend-protocol IProblem
  Optimizable$ByGradientValue
  ;; Snapshot of a Mallet problem as plain Clojure data:
  ;;   :params   - vector of current parameter values
  ;;   :value    - objective value from .getValue
  ;;   :gradient - vector filled in by .getValueGradient
  (problem-state [^Optimizable$ByGradientValue opt]
    ;; Mallet fills caller-supplied buffers, so allocate fresh ones sized to
    ;; the parameter count on every query.
    (let [size      (.getNumParameters opt)
          param-buf (double-array size)
          grad-buf  (double-array size)]
      (.getParameters opt param-buf)
      (.getValueGradient opt grad-buf)
      {:params   (vec param-buf)
       :value    (.getValue opt)
       :gradient (vec grad-buf)})))
(defn- coll->array!
  "Copies the elements of `coll` into the double array `arr`, starting at
  index 0 and coercing each element with `double`. Returns `arr`.

  If `coll` has fewer elements than `arr`, the remaining slots are left
  untouched; if it has more, an ArrayIndexOutOfBoundsException is thrown."
  [coll ^doubles arr]
  ;; A plain loop avoids realizing a throwaway lazy seq (the old
  ;; doall/map-indexed form). The ^doubles hint lets `aset` resolve to the
  ;; primitive double[] overload instead of going through reflection.
  (loop [i 0
         xs (seq coll)]
    (when xs
      (aset arr i (double (first xs)))
      (recur (inc i) (next xs))))
  arr)
(defn problem
  "Wraps plain Clojure functions as a Mallet Optimizable$ByGradientValue.

  f              - called with the parameter array; its result is what
                   getValue returns.
  g              - called with the parameter array; its result (a seqable
                   of numbers) is copied into the buffer passed to
                   getValueGradient.
  initial-params - seqable of numbers used to seed the parameter array.

  NOTE: f and g receive the problem's live internal double array, not a
  copy, so they should not mutate it."
  [f g initial-params]
  (let [state (double-array initial-params)]
    (reify Optimizable$ByGradientValue
      (getNumParameters [_]
        (alength state))
      ;; Bulk getter/setter copy exactly (alength state) elements.
      (getParameters [_ buf]
        (System/arraycopy state 0 buf 0 (alength state)))
      (setParameters [_ src]
        (System/arraycopy src 0 state 0 (alength state)))
      (getParameter [_ i]
        (aget state i))
      (setParameter [_ i v]
        (aset state i v))
      (getValue [_]
        (f state))
      (getValueGradient [_ buf]
        (coll->array! (g state) buf)))))
(defn ->evaluator
  "Adapts a two-argument function into a Mallet
  OptimizerEvaluator$ByGradient. `evaluate` calls (f optimizable iteration)
  and returns its result. NOTE(review): Mallet presumably treats that
  boolean as 'keep going' - confirm against the Mallet javadoc."
  [f]
  (reify cc.mallet.optimize.OptimizerEvaluator$ByGradient
    (evaluate [_ optimizable iteration]
      (f optimizable iteration))))
(defn- optimize!
  "Runs `optimizer` on `problem` and returns (problem-state problem) with
  these extra keys:
    :converged?     - the result of .optimize, or false when an
                      IllegalArgumentException was caught
    :num-iterations - the last iteration number passed to the evaluator
                      (0 if it was never called)
    :exception      - the caught exception; present only in that case

  Options map:
    :max-iterations - when given, calls (.optimize optimizer max-iterations)
    :tolerance      - when given, passed to .setTolerance first

  Installs its own evaluator on `optimizer` via .setEvaluator."
  [optimizer problem {:keys [max-iterations tolerance]}]
  (when tolerance
    (.setTolerance optimizer tolerance))
  (let [last-iter (atom 0)]
    ;; Record the iteration counter and always let the optimizer continue.
    (.setEvaluator optimizer
                   (->evaluator (fn [_ iteration]
                                  (reset! last-iter iteration)
                                  true)))
    (let [[converged? ex]
          (try
            [(if max-iterations
               (.optimize optimizer max-iterations)
               (.optimize optimizer))
             nil]
            ;; L-BFGS may throw this when it cannot step in the current
            ;; direction, which does not necessarily mean the optimizer
            ;; failed - it just won't report success.
            ;; NOTE(review): this catch is broad and may also swallow real
            ;; argument errors.
            (catch IllegalArgumentException e
              [false e]))]
      (cond-> (assoc (problem-state problem)
                     :converged? converged?
                     :num-iterations @last-iter)
        ex (assoc :exception ex)))))
(defn lbfgs!
  "Runs Mallet's LimitedMemoryBFGS optimizer on `problem`.
  Accepts the keyword options understood by optimize! (:max-iterations,
  :tolerance) and returns its result map."
  [problem & {:as args}]
  (let [optimizer (LimitedMemoryBFGS. problem)]
    (optimize! optimizer problem args)))
(defn conjugate-gradient!
  "Runs Mallet's ConjugateGradient optimizer on `problem`.
  :step-size, when given, is passed as the constructor's second argument.
  Other keyword options go to optimize! (:max-iterations, :tolerance)."
  [problem & {:keys [step-size] :as args}]
  (optimize! (if step-size
               (ConjugateGradient. problem step-size)
               (ConjugateGradient. problem))
             problem
             args))
(defn gradient-ascent!
  "Runs Mallet's GradientAscent optimizer on `problem`.
  :step-size, when given, is applied via .setInitialStepSize before
  optimizing. Other keyword options go to optimize!."
  [problem & {:keys [step-size] :as args}]
  (let [optimizer (GradientAscent. problem)]
    (if step-size
      (.setInitialStepSize optimizer step-size))
    (optimize! optimizer problem args)))
;;;;;;;;;;;;;;;;;;;;;;
;; see http://mallet.cs.umass.edu/optimization.php
(defn problem-1
  "Quadratic test problem with value
  f(x, y) = -3x^2 - 4y^2 + 2x - 4y + 18
  and its analytic gradient, starting from (0, 0)."
  []
  (problem (fn [[x y]]
             (+ (* -3 x x) (* -4 y y) (* 2 x) (* -4 y) 18))
           ;; [df/dx df/dy] = [-6x + 2, -8y - 4]
           (fn [[x y]]
             [(+ (* x -6) 2)
              (- (* y -8) 4)])
           [0 0]))
(defn sqr
  "Returns `x` multiplied by itself."
  [x]
  (* x x))
;; see http://www.cas.mcmaster.ca/~cs4te3/tutorials/BFGS.pdf
(defn problem-2
  "Test problem with value -(e^(x-1) + e^(1-y) + (x-y)^2), i.e. the negation
  of the tutorial's objective, and the matching negated gradient. Starts
  from (0, 0)."
  []
  (let [value    (fn [[x y]]
                   (let [ex (Math/exp (dec x))
                         ey (Math/exp (inc (- y)))
                         d  (- x y)]
                     (- (+ ex ey (sqr d)))))
        ;; Negated partials: d/dx = e^(x-1) + 2(x-y),
        ;;                   d/dy = -e^(1-y) - 2(x-y)
        gradient (fn [[x y]]
                   (let [ex (Math/exp (dec x))
                         ey (Math/exp (inc (- y)))
                         d  (- x y)]
                     [(- (+ ex (* 2 d)))
                      (- (+ (- ey) (* -2 d)))]))]
    (problem value gradient [0 0])))
(comment
  ;; REPL example. The require sets up the `lm` alias and `pprint` so these
  ;; forms can be evaluated from any namespace, including this one.
  (require '[learn-mallet.core :as lm]
           '[clojure.pprint :refer [pprint]])
  (pprint (lm/lbfgs! (lm/problem-2))))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment