@Toekan
Created February 12, 2020 11:33
Training a simple neural net with one hidden layer on the XOR problem using the Clojure interface to MXNet
; Toy example that trains a simple neural net with one hidden layer to solve
; the XOR problem: https://www.deeplearningbook.org/contents/mlp.html
(ns my-mxnet.core
  (:require [org.apache.clojure-mxnet.eval-metric :as eval-metric]
            [org.apache.clojure-mxnet.io :as mx-io]
            [org.apache.clojure-mxnet.module :as m]
            [org.apache.clojure-mxnet.optimizer :as optimizer]
            [org.apache.clojure-mxnet.symbol :as sym]
            [org.apache.clojure-mxnet.symbol-api :as sym-api]
            [org.apache.clojure-mxnet.ndarray :as ndarray]))
(def X-values (ndarray/->ndarray [[0.0 0.0] [0.0 1.0] [1.0 0.0] [1.0 1.0]]))
(def y-values (ndarray/array [0.0 1.0 1.0 0.0] [4 1]))
(def train-data (mx-io/ndarray-iter [X-values]
                                    {:label [y-values] :data-batch-size 4
                                     :shuffle false :last-batch-handle "pad"
                                     :label-name "output_label"}))
(defn get-symbol-xor []
  (as-> (sym/variable "data") data
    (sym-api/fully-connected {:name "layer1" :data data :num-hidden 10})
    (sym-api/activation {:name "relu1" :data data :act-type "relu"})
    (sym-api/fully-connected {:name "layer2" :data data :num-hidden 1})
    (sym-api/linear-regression-output {:name "output" :data data})))
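; Illustrative sketch, not part of the original gist: the forward pass of a
; one-hidden-layer ReLU network written in plain Clojure, to show what the
; symbol graph above computes. The weights are the hand-picked XOR solution
; from the Deep Learning book chapter linked in the header comment (2 hidden
; units rather than the 10 used above). The names relu, dense and xor-by-hand
; are hypothetical helpers; a trained model will generally find different
; weights.
(defn relu [x]
  (max 0.0 x))

(defn dense
  "Affine layer: each row of `weights` holds the input weights of one unit."
  [weights biases inputs]
  (mapv (fn [w-row b] (+ b (reduce + (map * w-row inputs))))
        weights
        biases))

(defn xor-by-hand [x]
  (->> x
       (dense [[1.0 1.0] [1.0 1.0]] [0.0 -1.0]) ; hidden layer, 2 units
       (mapv relu)                              ; element-wise ReLU
       (dense [[1.0 -2.0]] [0.0])               ; linear output layer, 1 unit
       first))

; (mapv xor-by-hand [[0.0 0.0] [0.0 1.0] [1.0 0.0] [1.0 1.0]])
; => [0.0 1.0 1.0 0.0]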
(defn train-and-test-model []
  (let [mod (m/module (get-symbol-xor) {:label-names ["output_label"]})]
    (m/fit mod {:train-data train-data
                :num-epoch 1000
                :fit-params (m/fit-params
                             {:optimizer (optimizer/sgd {:learning-rate 0.1
                                                         :momentum 0.9
                                                         :wd 0.0001})})})
    ; Predicts ~ [0 1 1 0]
    (println (m/predict mod {:eval-data train-data}))
    ; Resulting in an MSE close to 0
    (println (m/score mod {:eval-data train-data
                           :eval-metric (eval-metric/mse)}))
    mod))
(def trained-model (train-and-test-model))
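; Illustrative sketch, not part of the original gist: println on the result of
; m/predict shows NDArray objects rather than numbers, so it can help to pull
; the values into a Clojure vector first. binarize is a hypothetical helper,
; and the 0.5 threshold is an assumption suited to 0/1 targets.
(defn binarize
  "Maps each raw regression output to 1 if it is at least 0.5, else 0."
  [preds]
  (mapv #(if (>= % 0.5) 1 0) preds))

; Evaluate at the REPL once training has finished. Assumes m/predict returns a
; sequence of NDArrays with one entry per network output; the result should be
; [0 1 1 0] if training has converged.
(comment
  (-> (m/predict trained-model {:eval-data train-data})
      first
      ndarray/->vec
      binarize))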