Last active
June 27, 2017 17:47
-
-
Save joycex99/8142cceab23a20f6bc1adfd6bf11570e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(defn f-beta
  "F-beta score of `precision` and `recall`:

      (1 + beta^2) * (precision * recall) / (beta^2 * precision + recall)

  `beta` defaults to 1, which gives the F1 score.

  When the denominator is zero (e.g. both precision and recall are zero) the
  quotient is treated as 0 instead of being computed, so the result is zero
  rather than an exception (exact arithmetic) or NaN (floating point)."
  ([precision recall] (f-beta precision recall 1))
  ([precision recall beta]
   (let [beta-squared (* beta beta)
         denominator  (+ (* beta-squared precision) recall)]
     (* (+ 1 beta-squared)
        ;; Guard explicitly instead of catching ArithmeticException: dividing
        ;; doubles by zero never throws on the JVM, it silently yields NaN,
        ;; so an exception handler only protects the integer/ratio case.
        (if (zero? denominator)
          0
          (/ (* precision recall) denominator))))))
(def high-score*
  "Atom holding a map whose :score entry is the best score recorded so far.
  Starts at 0."
  (atom {:score 0}))
(defn train
  "Runs (:epoch-count params) training epochs over the training dataset.

  After each epoch the freshly trained network is run against the test
  dataset; accuracy, precision, recall and F1 are logged, and the network is
  saved whenever its F1 exceeds the :score currently held in `high-score*`
  (which is then updated). Prints \"Done.\" and logs the best score at the end."
  []
  (let [context (execute/compute-context)] ; context shared by training and evaluation
    (execute/with-compute-context
      context
      (let [[train-ds test-ds] (get-train-test-dataset)
            initial-network    (network/linear-network network-description)
            ;; Converts one label vector into its class label.
            to-label           (fn [label-vec] (vec->label [0.0 1.0] label-vec))
            ;; One epoch: train, evaluate on the test set, log, maybe save.
            ;; Takes and returns the [network optimizer] pair threaded by reduce.
            run-epoch
            (fn [[current-network current-optimizer] epoch]
              (let [{next-network   :network
                     next-optimizer :optimizer} (execute/train current-network train-ds
                                                               :context context
                                                               :batch-size (:batch-size params)
                                                               :optimizer current-optimizer)
                    results   (execute/run next-network test-ds
                                           :context context
                                           :batch-size (:batch-size params))
                    actual    (mapv to-label (map :label test-ds))
                    predicted (mapv to-label (map :label results))
                    ;; test metrics
                    precision (metrics/precision actual predicted)
                    recall    (metrics/recall actual predicted)
                    f1        (f-beta precision recall)
                    accuracy  (softmax-loss/evaluate-softmax
                               (map :label results) (map :label test-ds))
                    report    (str "Epoch: " (inc epoch) "\n"
                                   "Test accuracy: " accuracy "\n"
                                   "Test precision: " precision "\n"
                                   "Test recall: " recall "\n"
                                   "Test F1: " f1 "\n\n")]
                (log report)
                ;; Keep only the best-scoring network on disk.
                (when (> f1 (:score @high-score*))
                  (reset! high-score* {:score f1})
                  (save next-network))
                [next-network next-optimizer]))]
        (reduce run-epoch
                [initial-network (:optimizer params)]
                (range (:epoch-count params)))
        (println "Done.")
        (log (str "Best score: " (:score @high-score*)))))))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment