Skip to content

Instantly share code, notes, and snippets.

@yogthos
Created April 1, 2011 20:22
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yogthos/898782 to your computer and use it in GitHub Desktop.
Save yogthos/898782 to your computer and use it in GitHub Desktop.
backpropagation neural net
(ns nn)
;; Struct basis for a three-layer network.
;;   :ai :ah :ao - activations of the input, hidden and output layers
;;   :wi :wo     - input->hidden and hidden->output weight matrices
;;   :ci :co     - most recent weight changes, reused as the momentum term
(def network
  (create-struct :ai :ah :ao :wi :wo :ci :co))
(defn rand-in-range
  "Returns a random floating point number drawn between a and b:
  a random fraction of the span (b - a), shifted by a."
  [a b]
  ;; (rand n) is n multiplied by a uniform sample from [0, 1).
  (+ (rand (- b a)) a))
(defn make-matrix
  "Builds an i-row by j-column matrix as a sequence of row sequences.
  Every cell holds `fill`, which defaults to 0.0."
  ([i j]
   (make-matrix i j 0.0))
  ([i j fill]
   (->> fill
        (repeat j)     ; one row of j cells
        (repeat i))))  ; i copies of that row
(defn sigmoid
  "Activation function. Despite the name this is the hyperbolic tangent,
  so results lie between -1 and 1."
  [x]
  (Math/tanh (double x)))
(defn dsigmoid
  "Derivative of the tanh activation, expressed in terms of the
  activation's *output*: for y = tanh(z), dy/dz = 1 - y^2.
  Pass the already-activated value, not the raw weighted sum."
  [x]
  (let [squared (* x x)]
    (- 1.0 squared)))
(defn randomize-matrix
  "Returns a lazy matrix with the same shape as m in which every cell is
  replaced by a fresh random value from rand-in-range -0.2 0.2.
  The original cell values are ignored; only the shape of m matters."
  [m]
  (map (fn [row]
         (for [_ row]
           (rand-in-range -0.2 0.2)))
       m))
(defn make-network
  "Creates a fresh network struct.
    inputs - number of input units (one extra bias unit is added)
    nh     - number of hidden units
    no     - number of output units
  Activations start at 1.0, weights are randomized, and the
  momentum (previous-change) matrices start at 0.0."
  [inputs nh no]
  (let [ni             (inc inputs) ; +1 for the bias unit
        ones           (fn [n] (repeat n 1.0))
        random-weights (fn [rows cols]
                         (randomize-matrix (make-matrix rows cols)))]
    (struct-map network
                ;; activations
                :ai (ones ni)
                :ah (ones nh)
                :ao (ones no)
                ;; weights
                :wi (random-weights ni nh)
                :wo (random-weights nh no)
                ;; change in weights for momentum
                :ci (make-matrix ni nh)
                :co (make-matrix nh no))))
(defn update-activations
  "Feeds one layer forward. `weights` holds one row per source
  activation; each row lists the weights from that source unit to every
  unit in the next layer. Returns the next layer's activations: the
  element-wise sum of the scaled rows, passed through sigmoid."
  [activations weights]
  (let [scale-row (fn [a row] (map #(* a %) row))
        add-rows  (fn [sums vals] (map + sums vals))
        weighted  (map scale-row activations weights)]
    (map sigmoid (reduce add-rows weighted))))
(defn update
  "Runs a forward pass. Returns `network` with :ai, :ah and :ao replaced
  by the activations produced for `inputs`.

  `inputs` must contain one value per non-bias input unit (asserted).
  Any sequential collection is accepted: it is copied into a vector so
  the bias activation is always appended *last*, matching the final row
  of :wi. A plain `conj` would prepend the bias for lists and seqs and
  silently pair every input with the wrong weight row.

  NOTE(review): this name shadows clojure.core/update on Clojure
  versions that define it - confirm whether a :refer-clojure exclusion
  is wanted in the ns form."
  [network inputs]
  (assert (= (dec (count (:ai network))) (count inputs)))
  (let [bias (last (:ai network))          ; bias unit keeps its current value
        ai   (conj (vec inputs) bias)      ; input activations, bias at the end
        ah   (update-activations ai (:wi network))
        ao   (update-activations ah (:wo network))]
    (assoc network :ai ai, :ah ah, :ao ao)))
(defn back-propagate
  "Performs one back-propagation step against the activations currently
  stored in `network` (i.e. call `update` first).
    targets - desired output values, one per output unit (asserted)
    N       - learning rate applied to the new weight changes
    M       - momentum factor applied to the previous weight changes
    error   - running error total to add this pattern's error to
  Returns a pair [updated-network new-error], where new-error is
  `error` plus the sum of 0.5 * (target - output)^2 over the outputs."
  [network targets N M error]
  (assert (= (count (:ao network)) (count targets)))
  (letfn [;; One row per activation, one column per delta: delta * activation.
          (change-matrix [activations deltas]
            (for [a activations]
              (map (fn [delta] (* delta a)) deltas)))
          ;; New weight = old weight + N * new change + M * previous change.
          (adjust-row [weights changes new-changes]
            (doall (map (fn [w oc nc] (+ w (* N nc) (* M oc)))
                        weights changes new-changes)))
          (adjust-matrix [weights old-changes new-changes]
            (map adjust-row weights old-changes new-changes))
          (squared-error [target a]
            (* 0.5 (Math/pow (- target a) 2)))]
    (let [outputs       (:ao network)
          hidden        (:ah network)
          ;; error signal at each output unit
          output-deltas (map (fn [target a]
                               (* (dsigmoid a) (- target a)))
                             targets outputs)
          ;; error signal at each hidden unit, pulled back through :wo
          hidden-deltas (map (fn [w a]
                               (* (dsigmoid a)
                                  (reduce + (map * w output-deltas))))
                             (:wo network) hidden)
          new-co        (change-matrix hidden output-deltas)
          new-ci        (change-matrix (:ai network) hidden-deltas)
          new-wo        (adjust-matrix (:wo network) (:co network) new-co)
          new-wi        (adjust-matrix (:wi network) (:ci network) new-ci)
          pattern-error (reduce + (map squared-error targets outputs))]
      [(assoc network :wo new-wo, :co new-co, :wi new-wi, :ci new-ci)
       (+ error pattern-error)])))
(defn train
  "Trains `network` on `patterns` for `iterations` passes.
    patterns   - sequence of [inputs targets] pairs
    iterations - number of full passes over the patterns
    N          - learning rate
    M          - momentum factor
  Each pass runs a forward pass and a back-propagation step for every
  pattern in order, then prints the iteration index and the error
  accumulated over that pass. Returns the trained network."
  [network, patterns, iterations, N, M]
  (letfn [;; Forward pass + weight update for a single pattern.
          (learn-pattern [[net err] [inputs targets]]
            (back-propagate (update net inputs) targets N M err))
          ;; One full pass over every pattern; error restarts at 0.0.
          (run-pass [net iteration]
            (let [[trained error] (reduce learn-pattern [net 0.0] patterns)]
              (println "iteration:" iteration ", error:" error)
              trained))]
    (reduce run-pass network (range iterations))))
(defn test-network
  "For every [inputs targets] pattern, runs a forward pass on the inputs
  and prints them alongside the resulting output activations.
  The targets are ignored. Returns nil."
  [network patterns]
  (doseq [[inputs] patterns]
    (let [outputs (:ao (update network inputs))]
      (println inputs "->" outputs))))
(defn demo
  "Trains a 2-input, 2-hidden, 1-output network on the XOR truth table
  for 1000 iterations (learning rate 0.5, momentum 0.1) and prints the
  trained network's output for each input pair."
  []
  (let [xor-patterns [[[0 0] [0]]
                      [[0 1] [1]]
                      [[1 0] [1]]
                      [[1 1] [0]]]
        untrained    (make-network 2 2 1)
        trained      (train untrained xor-patterns 1000 0.5 0.1)]
    (test-network trained xor-patterns)))
;; Runs the XOR demo as soon as this file is loaded and prints how long it took.
(time (demo))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment