Skip to content

Instantly share code, notes, and snippets.

@joycex99
Last active June 27, 2017 17:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save joycex99/8142cceab23a20f6bc1adfd6bf11570e to your computer and use it in GitHub Desktop.
Save joycex99/8142cceab23a20f6bc1adfd6bf11570e to your computer and use it in GitHub Desktop.
(defn f-beta
  "F-beta score: weighted harmonic mean of precision and recall.

    (1 + beta^2) * (precision * recall) / (beta^2 * precision + recall)

  beta defaults to 1, i.e. the F1 score. When the denominator is zero
  (e.g. precision and recall are both 0) the score is defined as 0
  instead of dividing by zero."
  ([precision recall] (f-beta precision recall 1))
  ([precision recall beta]
   (let [beta-squared (* beta beta)
         denominator (+ (* beta-squared precision) recall)]
     ;; Check the divisor explicitly rather than catching ArithmeticException:
     ;; whether `/` throws on zero depends on boxed vs. primitive arithmetic
     ;; (primitive double division yields NaN/Infinity instead of throwing).
     (if (zero? denominator)
       0
       (* (+ 1 beta-squared)
          (/ (* precision recall) denominator))))))
(def high-score*
  "Atom holding the best test score seen so far, as a map {:score <number>}.
  Starts at 0; `train` replaces it whenever an epoch beats the stored score."
  (atom (array-map :score 0)))
(defn train
  "Train the network for (:epoch-count params) epochs, saving the best results as we go.

  Each epoch trains on the training set, runs the network on the test set and
  logs test accuracy, precision, recall and F1. Whenever the test F1 is strictly
  greater than the score stored in `high-score*`, that atom is updated and the
  network is passed to `save`.

  Note: `high-score*` is not reset here, so within one session a later call only
  saves a network that beats the best score from any earlier call.

  Relies on namespace-level `params`, `network-description`,
  `get-train-test-dataset`, `vec->label`, `log` and `save`."
  []
  (let [context (execute/compute-context)] ; determines context for gpu/cpu training
    (execute/with-compute-context
      context
      (let [[train-ds test-ds] (get-train-test-dataset)
            network (network/linear-network network-description)
            ;; The ground-truth test labels do not change between epochs, so
            ;; extract and decode them once instead of inside the reduce.
            test-labels (map :label test-ds)
            test-actual (mapv #(vec->label [0.0 1.0] %) test-labels)]
        (reduce (fn [[network optimizer] epoch]
                  (let [{:keys [network optimizer]} (execute/train network train-ds
                                                                   :context context
                                                                   :batch-size (:batch-size params)
                                                                   :optimizer optimizer)
                        test-results (execute/run network test-ds :context context
                                                  :batch-size (:batch-size params))
                        test-pred (mapv #(vec->label [0.0 1.0] %) (map :label test-results))
                        ;;; test metrics
                        test-precision (metrics/precision test-actual test-pred)
                        test-recall (metrics/recall test-actual test-pred)
                        test-f-beta (f-beta test-precision test-recall)
                        test-accuracy (softmax-loss/evaluate-softmax
                                        (map :label test-results) test-labels)]
                    (log (str "Epoch: " (inc epoch) "\n"
                              "Test accuracy: " test-accuracy "\n"
                              "Test precision: " test-precision "\n"
                              "Test recall: " test-recall "\n"
                              "Test F1: " test-f-beta "\n\n"))
                    ;; Keep only the best-scoring network seen so far.
                    (when (> test-f-beta (:score @high-score*))
                      (reset! high-score* {:score test-f-beta})
                      (save network))
                    ;; Thread the updated network and optimizer into the next epoch.
                    [network optimizer]))
                [network (:optimizer params)]
                (range (:epoch-count params)))
        (println "Done.")
        (log (str "Best score: " (:score @high-score*)))))))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment