Created
February 12, 2020 11:33
-
-
Save Toekan/0d180f129c3bd3036a041149f87ac85e to your computer and use it in GitHub Desktop.
Training a simple neural net with one hidden layer on the XOR problem using Clojure's MXNet interface
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
; Toy example that trains a simple neural net with one hidden layer on the XOR
; problem: https://www.deeplearningbook.org/contents/mlp.html
(ns my-mxnet.core | |
(:require [org.apache.clojure-mxnet.eval-metric :as eval-metric] | |
[org.apache.clojure-mxnet.io :as mx-io] | |
[org.apache.clojure-mxnet.module :as m] | |
[org.apache.clojure-mxnet.optimizer :as optimizer] | |
[org.apache.clojure-mxnet.symbol :as sym] | |
[org.apache.clojure-mxnet.symbol-api :as sym-api] | |
[org.apache.clojure-mxnet.ndarray :as ndarray])) | |
(def X-values
  "The four XOR inputs, one example per row: (0,0), (0,1), (1,0), (1,1)."
  (ndarray/array [0.0 0.0
                  0.0 1.0
                  1.0 0.0
                  1.0 1.0]
                 [4 2]))

(def y-values
  "XOR targets for the rows of X-values, laid out as a 4x1 column."
  (ndarray/->ndarray [[0.0] [1.0] [1.0] [0.0]]))
(def train-data
  "Iterator over X-values/y-values that yields all four examples in one
  unshuffled batch; labels are exposed under the name \"output_label\"."
  (let [iter-opts {:label             [y-values]
                   :label-name        "output_label"
                   :data-batch-size   4
                   :last-batch-handle "pad"
                   :shuffle           false}]
    (mx-io/ndarray-iter [X-values] iter-opts)))
(defn get-symbol-xor
  "Builds the XOR network symbol:
  data -> fully-connected (10 units, \"layer1\") -> relu (\"relu1\")
       -> fully-connected (1 unit, \"layer2\") -> linear-regression-output (\"output\").
  NOTE(review): the label is presumably picked up as \"output_label\" (the name
  used for the iterator/module labels in this file) - confirm against MXNet docs."
  []
  (let [input  (sym/variable "data")
        hidden (sym-api/fully-connected {:name "layer1" :data input :num-hidden 10})
        relu   (sym-api/activation {:name "relu1" :data hidden :act-type "relu"})
        out    (sym-api/fully-connected {:name "layer2" :data relu :num-hidden 1})]
    (sym-api/linear-regression-output {:name "output" :data out})))
(defn train-and-test-model
  "Trains the XOR network from get-symbol-xor on train-data for 1000 epochs of
  SGD (learning rate 0.1, momentum 0.9, weight decay 0.0001). It then prints
  the model's predictions and its mean-squared error on that same data, and
  returns the trained module."
  []
  ;; Named `model` rather than `mod` so the local does not shadow clojure.core/mod.
  (let [model (m/module (get-symbol-xor) {:label-names ["output_label"]})]
    (m/fit model {:train-data train-data
                  :num-epoch 1000
                  :fit-params (m/fit-params
                               {:optimizer (optimizer/sgd {:learning-rate 0.1
                                                           :momentum 0.9
                                                           :wd 0.0001})})})
    ;; m/predict yields NDArrays (one per output); convert them to plain
    ;; vectors so the printed output shows the values, expected ~ [0 1 1 0].
    (println (mapv ndarray/->vec (m/predict model {:eval-data train-data})))
    ;; The metric is MSE (not RMSE); after training it should be close to 0.
    (println (m/score model {:eval-data train-data
                             :eval-metric (eval-metric/mse)}))
    model))
(def trained-model
  "Module produced by train-and-test-model; training runs when this namespace loads."
  (train-and-test-model))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment