Last active
January 12, 2017 22:25
-
-
Save ckirkendall/3be676d658bbcb7d51cb26d8f5eeb9f2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(ns wine-fun.data | |
(:require [clojure.java.io :as io] | |
[clojure.data.csv :as csv])) | |
(def training-data
  ;; Reads the training CSV, discards its first (header) row, holds out the
  ;; first tenth of the remaining rows as a control set, and parses the rest.
  ;;   :input   - first 11 columns of each non-control row, as doubles
  ;;   :control - first 12 columns of each control row, as doubles
  ;;              (11 inputs followed by the target at index 11)
  ;;   :target  - column 11 of each non-control row, as a double
  (let [rows          (with-open [rdr (io/reader "data/winequality-data.csv")]
                        (->> (csv/read-csv rdr)
                             doall
                             (drop 1)))
        n-control     (int (/ (count rows) 10))
        parse-first   (fn [k row]
                        (into [] (map #(Double/parseDouble %)) (take k row)))
        training-rows (drop n-control rows)]
    {:input   (mapv (partial parse-first 11) training-rows)
     :control (mapv (partial parse-first 12) (take n-control rows))
     :target  (mapv #(Double/parseDouble (nth % 11)) training-rows)}))
(def test-data
  ;; Every row of the solution-input CSV, fully realized before the reader is
  ;; closed. No header row is dropped and no values are parsed here.
  (with-open [rdr (io/reader "data/winequality-solution-input.csv")]
    (-> rdr csv/read-csv doall)))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(ns wine-fun.linear-regression | |
(:require [clojure.core.matrix :as m] | |
[clojure.core.matrix.operators :as op] | |
[clojure.core.matrix.linear :as linear] | |
[wine-fun.data :as data])) | |
;; Select vectorz as the current core.matrix implementation.
(-> :vectorz m/set-current-implementation)
(defn hypot
  "Multiplies the transpose of `thetas` by `xs` using core.matrix `mmul`."
  [thetas xs]
  (let [thetas-t (m/transpose thetas)]
    (m/mmul thetas-t xs)))
(defn sum
  "Splits `m` into slices with `func` (e.g. `m/rows` or `m/columns`) and
  returns a vector holding the element sum (`m/esum`) of each slice."
  [func m]
  (->> (func m)
       (map m/esum)
       m/to-vector))
;; Placeholder with no parameters and no implementation: calling it returns
;; nil. Nothing in this namespace calls it.
(defn mult-rows
  []
  nil)
(defn gd-step
  "Performs one batch gradient-descent update for a linear model with no
  intercept term.

  x        - matrix with one row per sample and one column per feature
  y        - vector of observed targets, one per row of x
  inv-rcnt - step scale; `batch-gradient-decent` passes alpha / row-count
  thetas   - current coefficient vector, one entry per column of x

  Returns thetas + inv-rcnt * x^T (y - x . thetas)."
  [x y inv-rcnt thetas]
  (let [predictions      (m/mmul x thetas)
        scaled-residuals (m/mul (m/sub y predictions) inv-rcnt)]
    ;; x^T . r sums each column of x weighted by that row's residual. This is
    ;; a single matrix-vector product rather than a per-element closure plus
    ;; a separate column sum.
    (m/add thetas (m/mmul (m/transpose x) scaled-residuals))))
(defn batch-gradient-decent
  "Fits linear-model coefficients by batch gradient descent, starting from
  an all-zero thetas vector.

  x     - matrix, one row per sample
  y     - vector of targets, one per row of x
  alpha - learning rate
  opts  - optional map:
          :max-iter  stop after the step taken once the 0-based iteration
                     counter exceeds this value (default 10000)
          :tolerance stop when consecutive thetas are equal within this
                     epsilon according to `m/equals` (default 0.000001)
          :log-every print thetas (T:) and the L1 size of the last update
                     (E:) every this many iterations; nil or non-positive
                     disables logging (default 1000)

  Returns the final thetas vector."
  ([x y alpha]
   (batch-gradient-decent x y alpha {}))
  ([x y alpha {:keys [max-iter tolerance log-every]
               :or   {max-iter 10000 tolerance 0.000001 log-every 1000}}]
   (let [n-cols     (m/column-count x)
         n-rows     (m/row-count x)
         step-scale (* (/ 1 n-rows) alpha)]
     (loop [loop-cnt 0
            thetas   (m/to-vector (repeat n-cols 0))]
       (let [v-thetas (gd-step x y step-scale thetas)]
         (if (or (> loop-cnt max-iter)
                 (m/equals thetas v-thetas tolerance))
           v-thetas
           (do
             (when (and log-every
                        (pos? log-every)
                        (zero? (mod loop-cnt log-every)))
               (println "T:" v-thetas)
               ;; L1 norm of the change in thetas made by this step.
               (println "E:" (m/esum (m/abs (m/sub thetas v-thetas)))))
             (recur (inc loop-cnt) v-thetas))))))))
(defn matrix-gradient-decent
  "Computes linear-model coefficients in closed form from the normal
  equation thetas = (x^T x)^-1 x^T y. Despite the name, no iterative
  descent is performed.

  Throws ex-info when `m/inverse` returns nil for x^T x. core.matrix
  documents nil for a singular matrix; some backends may throw instead."
  [x y]
  (let [xt      (m/transpose x)
        xtx     (m/mmul xt x)
        xtx-inv (m/inverse xtx)]
    (when (nil? xtx-inv)
      (throw (ex-info "x^T x is singular; cannot solve the normal equation"
                      {:shape (m/shape xtx)})))
    ;; Form x^T y first so that only a small vector is multiplied by the
    ;; inverse. Multiplying xtx-inv by xt first would build a (columns x rows)
    ;; intermediate matrix.
    (m/mmul xtx-inv (m/mmul xt y))))
;; Dispatch table used by `run`: method keyword -> regression function.
;; `run` calls each with x and y, plus a learning rate when one is supplied
;; (as for :batch-gd); the result is used as the thetas vector.
(def reg-funcs
  (array-map
   :batch-gd  batch-gradient-decent
   :matrix-gd matrix-gradient-decent
   :least-sqr linear/least-squares))
(defn- mean-abs-error
  "Mean absolute error of the linear model `thetas` over `control`. Each
  control row holds the feature values followed by the observed target in
  its last position. Returns nil when `control` is empty."
  [thetas control]
  (when (seq control)
    (let [abs-err (fn [row]
                    (let [predicted (reduce + (map * thetas (butlast row)))]
                      (Math/abs (- (last row) predicted))))]
      (/ (reduce + 0 (map abs-err control))
         (count control)))))

(defn run
  "Fits thetas with the method registered under `func` in `reg-funcs`, then
  prints the method name, the thetas, and the mean absolute error over
  `control`. The whole fit-and-report runs inside `time`, so the elapsed
  time is printed as well.

  func    - key into `reg-funcs` (e.g. :batch-gd, :matrix-gd, :least-sqr)
  x, y    - training inputs and targets passed to the fitting function
  step    - optional extra argument (learning rate for :batch-gd); passed
            only when non-nil
  control - rows of feature values followed by the observed target

  Throws ex-info for an unknown `func`. Prints \"Error:  nil\" when
  `control` is empty. Returns nil."
  ([func x y control]
   (run func x y nil control))
  ([func x y step control]
   (time
    (let [fit-fn (or (reg-funcs func)
                     (throw (ex-info (str "Unknown regression method: " func)
                                     {:func func :known (keys reg-funcs)})))
          thetas (if step (fit-fn x y step) (fit-fn x y))
          error  (mean-abs-error thetas control)]
      (println "Method:" (name func))
      (println "Thetas: " thetas)
      (println "Error: " error)))))
(comment
  ;; REPL driver: fit each regression method on the training rows and
  ;; report its error against the held-out control rows.
  (let [{:keys [input target control]} data/training-data
        x (m/matrix input)
        y (m/to-vector target)]
    (println "Starting")
    (run :batch-gd x y 0.00004 control)
    (run :matrix-gd x y control)
    (run :least-sqr x y control)))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Method: batch-gd
Thetas: [0.006431930492348725 -0.5528332715006099 0.06759091050641436 0.026698146203819523 -0.005776954243954035 0.006658957977980936 -0.00160701061370485 0.1429976619372882 0.477799819390825 0.18796491717313363 0.38301467221089935]
Error: 0.5630484543418387
"Elapsed time: 5369544.763433 msecs"
Method: matrix-gd
Thetas: #vectorz/vector [-0.03834677746391399,-1.9709278477762922,-0.04878694353371084,0.026068265886936417,-0.6763137716126468,0.004331197888668588,-7.234522797890264E-4,1.8169396095584391,0.16274153572638014,0.3303310066445396,0.3836755986839736]
Error: 0.5420093578819823
"Elapsed time: 4.606669 msecs"
Method: least-sqr
Thetas: #vectorz/vector [-0.038346777464427585,-1.9709278477787677,-0.048786943531438774,0.02606826588689189,-0.6763137716158678,0.004331197888660421,-7.234522797880334E-4,1.816939609577304,0.16274153572233313,0.33033100664458875,0.383675598683793]
Error: 0.5420093578819263
"Elapsed time: 218.852223 msecs"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment