;; primitiv XOR example (Clojure)
;; Author: @vbkaisetsu, created December 27, 2017 14:39
;; Source: GitHub Gist 492efc3f4b8757562d16f8cc03de14f5
(ns xor-example.core
(:import
[primitiv Device Graph Parameter Shape]
[primitiv functions functions$batch]
[primitiv.devices Naive]
[primitiv.initializers XavierUniform Constant]
[primitiv.optimizers SGD]))
(defn -main
  "Trains a small two-layer network with primitiv on a four-sample
  XOR-style problem. For every epoch it prints each sample's network
  output and the loss.

  Data are encoded as +/-1. Assuming input-data holds one sample (x1 x2)
  per row as written, the inputs (1 1) (1 -1) (-1 1) (-1 -1) map to the
  targets 1 -1 -1 1, i.e. target = x1 * x2.

  Model: y = w2 * tanh(w1 * x + b1) + b2. Shapes: w1 is [8 2], b1 is [8],
  w2 is [1 8] and b2 is a scalar []. Weights use XavierUniform; biases
  use Constant 0.

  Loss: functions$batch/mean of (t - y)^2.
  Optimizer: SGD with learning rate 0.1.

  args: optional command-line arguments. When a first argument is given,
  it is parsed with Long/parseLong as the number of training epochs;
  otherwise 10 epochs run, as before. A non-numeric argument throws
  NumberFormatException. Extra arguments are ignored."
  [& args]
  (let [epochs (if-let [arg (first args)] (Long/parseLong arg) 10)
        ;; Four samples of two features each, flattened row by row.
        ;; NOTE(review): the pairing with output-data assumes primitiv reads
        ;; this as one [2] column per batch item -- confirm.
        input-data [1 1
                    1 -1
                    -1 1
                    -1 -1]
        output-data [1 -1 -1 1]
        dev (Naive.)
        g (Graph.)]
    ;; Register dev/g as defaults. The functions/* calls below take no
    ;; device or graph argument, so they presumably pick these up.
    (Device/set_default dev)
    (Graph/set_default g)
    (let [pw1 (Parameter. (Shape. [8 2]) (XavierUniform.))
          pb1 (Parameter. (Shape. [8]) (Constant. 0))
          pw2 (Parameter. (Shape. [1 8]) (XavierUniform.))
          pb2 (Parameter. (Shape. []) (Constant. 0))
          optimizer (SGD. 0.1)]
      (.add optimizer [pw1 pb1 pw2 pb2])
      (dotimes [epoch epochs]
        ;; Clear the graph so this epoch's nodes are rebuilt from scratch.
        (.clear g)
        (let [;; (Shape. [2] 4): presumably dims [2] with a batch of 4.
              x (functions/input (Shape. [2] 4) input-data)
              w1 (functions/parameter pw1)
              b1 (functions/parameter pb1)
              w2 (functions/parameter pw2)
              b2 (functions/parameter pb2)
              h (functions/tanh (functions/add (functions/matmul w1 x) b1))
              y (functions/add (functions/matmul w2 h) b2)
              y-values (.to_array y)
              t (functions/input (Shape. [] 4) output-data)
              diff (functions/subtract t y)
              loss (functions$batch/mean (functions/multiply diff diff))
              loss-value (.to_float loss)]
          (println "epoch " epoch ":")
          ;; Use a distinct name so the graph node `y` is not shadowed.
          (doseq [[j y-j] (map-indexed vector y-values)]
            (println " [" j "]: " y-j))
          (println " loss: " loss-value)
          ;; One optimization step: reset gradients, backpropagate from
          ;; the loss, then apply the SGD update to the registered params.
          (.reset_gradients optimizer)
          (.backward loss)
          (.update optimizer))))))
;; epoch 0 :
;; [ 0 ]: -0.51549953
;; [ 1 ]: 0.7876314
;; [ 2 ]: -0.7876314
;; [ 3 ]: 0.51549953
;; loss: 1.4430515
;;
;; ... (output for epochs 1-7 omitted) ...
;;
;; epoch 8 :
;; [ 0 ]: 0.06495705
;; [ 1 ]: -0.05289926
;; [ 2 ]: -0.056721658
;; [ 3 ]: 0.05311704
;; loss: 0.88941664
;; epoch 9 :
;; [ 0 ]: 0.07288616
;; [ 1 ]: -0.072105706
;; [ 2 ]: -0.05547892
;; [ 3 ]: 0.06541988
;; loss: 0.871522
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment