Last active
January 26, 2016 08:49
-
-
Save Jah524/23538f678d7b1441514e to your computer and use it in GitHub Desktop.
Clojureで0からのニューラルネット構築(フィードフォワードモデル)と隠れ層の挙動調査
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(ns feedforward | |
(:require [incanter.core :refer [exp sin view]] | |
[incanter.charts :refer [function-plot add-function set-stroke-color]])) | |
(defn sigmoid
  "Logistic function 1 / (1 + e^(-x)) for a scalar x."
  [x]
  (let [e-neg-x (Math/exp (- x))]
    (/ 1 (inc e-neg-x))))
(defn unit-output
  "Output of a single unit: the activation function applied to
  bias + sum_i input-list[i] * w-list[i].

  activate-fn-key must be :sigmoid or :linear; any other key throws
  IllegalArgumentException."
  [input-list w-list bias activate-fn-key]
  (let [activate (case activate-fn-key
                   :sigmoid sigmoid
                   :linear identity)
        ;; Starting the fold at bias keeps the original summation order:
        ;; ((bias + p1) + p2) + ...
        weighted-sum (reduce + bias (map * input-list w-list))]
    (activate weighted-sum)))
(defn network-output
  "Forward pass of x-list through w-network.

  Returns a seq of activations: x-list itself first, followed by the output
  list of each layer in order. (last ...) is therefore the network's final
  output and (second ...) is the first hidden layer's output."
  [w-network x-list]
  (letfn [(layer-output [inputs {:keys [activate-fn units]}]
            (map (fn [{:keys [bias w-list]}]
                   (unit-output inputs w-list bias activate-fn))
                 units))]
    (reductions layer-output x-list w-network)))
(defn derivative-value
  "Derivative of a unit's activation function, expressed in terms of the
  unit's already-activated output y = f(a), which is what network-output
  produces.

    :sigmoid -> y * (1 - y)   ; because sigmoid'(a) = sigmoid(a) * (1 - sigmoid(a))
    :linear  -> 1

  Any other activation key throws IllegalArgumentException."
  [unit-output activate-fn]
  (condp = activate-fn
    ;; unit-output is already sigmoid(a); applying sigmoid again here
    ;; would compute the derivative at the wrong point.
    :sigmoid (* unit-output (- 1 unit-output))
    :linear 1))
(defn back-propagation
  "Runs one step of online gradient descent (backpropagation) on a single
  training sample and returns the updated network.

  w-network     - seq of layers, each {:activate-fn kw
                                       :units [{:bias b :w-list [w ...]} ...]},
                  ordered from the first layer after the input to the output
                  layer (the shape init-w-network builds).
  training-x    - input vector of the sample.
  training-y    - target vector of the sample.
  learning-rate - step size applied to every weight and bias.

  The error is 0.5 * sum((output - y)^2), so the output-layer delta is
  (output - y) * f'(output). Each weight is updated as
  w - learning-rate * delta * input."
  [w-network training-x training-y learning-rate]
  (if (empty? w-network)
    []
    (let [reversed-w-network (reverse w-network)
          ;; Activations from the output layer down to the raw input.
          reversed-output-net (reverse (network-output w-network training-x))
          top-activate-fn (:activate-fn (first reversed-w-network))
          output-deltas (mapv (fn [y out]
                                (* (- out y) (derivative-value out top-activate-fn)))
                              training-y (first reversed-output-net))]
      (loop [layers reversed-w-network
             outputs reversed-output-net
             delta-list output-deltas
             acc []]
        (if-let [w-layer (first layers)]
          (let [input-layer (second outputs)
                lower-layer (second layers)
                updated-layer {:units (mapv (fn [{bias :bias w-list :w-list} delta]
                                              {:w-list (mapv (fn [w input]
                                                               (- w (* learning-rate delta input)))
                                                             w-list input-layer)
                                               :bias (- bias (* learning-rate delta))})
                                            (:units w-layer) delta-list)
                               :activate-fn (:activate-fn w-layer)}
                ;; Deltas for the layer below, computed from this layer's
                ;; pre-update weights. Skipped when the layer below is the raw
                ;; input: it has no activation function, and derivative-value
                ;; would throw on a nil key.
                next-delta-list (when lower-layer
                                  (vec (map-indexed
                                        (fn [index unit-output]
                                          (* (reduce + (map (fn [delta unit]
                                                              (* delta (nth (:w-list unit) index)))
                                                            delta-list (:units w-layer)))
                                             (derivative-value unit-output
                                                               (:activate-fn lower-layer))))
                                        input-layer)))]
            ;; Layers are visited top-down; consing restores bottom-up order.
            (recur (rest layers) (rest outputs) next-delta-list (cons updated-layer acc)))
          acc)))))
(defn init-w-network
  "Builds a randomly initialised network from network-info, a seq of
  {:unit-num n :activate-fn kw} maps ordered from the input layer upward.

  The input entry produces no layer of its own; it only fixes the weight
  count of the next layer. Every bias and weight is drawn from (rand),
  i.e. uniformly from [0, 1)."
  [network-info]
  (apply list
         (map (fn [{fan-in :unit-num} {n :unit-num a :activate-fn}]
                {:activate-fn a
                 :units (repeatedly n (fn [] {:bias (rand)
                                              :w-list (repeatedly fan-in rand)}))})
              network-info
              (rest network-info))))
(defn train
  "One pass over training-list: applies back-propagation once per sample, in
  list order, threading the network through. Returns the resulting network,
  or w-network unchanged when training-list is empty."
  [w-network training-list learning-rate]
  (reduce (fn [net {x :training-x y :training-y}]
            (back-propagation net x y learning-rate))
          w-network
          training-list))
(defn sum-of-squares-error
  "Total error of w-network over training-list: the sum over samples of
  0.5 * sum_k (output_k - y_k)^2.

  Summation stops at the first entry that lacks :training-x or :training-y.
  Returns 0 for an empty list."
  [w-network training-list]
  (let [complete? (fn [{:keys [training-x training-y]}]
                    (and training-x training-y))
        sample-error (fn [{:keys [training-x training-y]}]
                       (let [prediction (last (network-output w-network training-x))]
                         (reduce + (mapv #(* 0.5 (- %1 %2) (- %1 %2)) prediction training-y))))]
    (->> training-list
         (take-while complete?)
         (map sample-error)
         (reduce + 0))))
(defn training-loop
  "Trains w-network for epoc epochs. Each epoch trains on a freshly shuffled
  copy of training-list, then prints the remaining epoch count, the network
  and its error over the unshuffled training-list.

  Returns the final network, or w-network itself when epoc <= 0."
  [w-network training-list learning-rate epoc]
  ;; Counting down from epoc to 1 reproduces the epoch numbers in the log.
  (reduce (fn [net remaining]
            (let [trained (train net (shuffle training-list) learning-rate)
                  error (sum-of-squares-error trained training-list)]
              (println (str "epoc=> " remaining "\nw-network=> " trained "\nerror=> " error "\n"))
              trained))
          w-network
          (range epoc 0 -1)))
;;;; | |
(defn- sin-training-list
  "Samples sin on [from, to) with step 0.2, one
  {:training-x [x] :training-y [(sin x)]} map per point."
  [from to]
  (for [x (range from to 0.2)]
    {:training-x [x] :training-y [(sin x)]}))

(def training-list-sin3 (sin-training-list -3 3))
(def training-list-sin10 (sin-training-list -10 10))
(defn- plot-sin-fit
  "Trains a 1 -> hidden-num (sigmoid) -> 1 (linear) network on training-list,
  then shows one plot over [x-min, x-max] containing: sin, the network's
  output, and (in gray) the output of each hidden unit."
  ([hidden-num training-list x-min x-max]
   (plot-sin-fit hidden-num training-list x-min x-max 0.05 10000))
  ([hidden-num training-list x-min x-max learning-rate epoc]
   (let [w-network (training-loop (init-w-network [{:unit-num 1 :activate-fn :linear}
                                                   {:unit-num hidden-num :activate-fn :sigmoid}
                                                   {:unit-num 1 :activate-fn :linear}])
                                  training-list learning-rate epoc)
         base-plot (-> (function-plot sin x-min x-max)
                       (add-function #(first (last (network-output w-network [%]))) x-min x-max))
         nn-plot (reduce (fn [plot counter]
                           (-> plot
                               (add-function #(nth (second (network-output w-network [%])) counter)
                                             x-min x-max)
                               ;; (+ 2 counter) skips the two curves added
                               ;; before: sin and the network output.
                               (set-stroke-color java.awt.Color/gray :dataset (+ 2 counter))))
                         base-plot
                         (range hidden-num))]
     (view nn-plot))))

(defn sample-sin-3
  "Fits sin on [-3, 3) with 3 hidden units and plots the result."
  []
  (plot-sin-fit 3 training-list-sin3 -3 3))

(defn sample-bad-sin-10
  "Fits sin on [-10, 10) with only 3 hidden units (too few to fit well) and
  plots the result."
  []
  (plot-sin-fit 3 training-list-sin10 -10 10))

(defn sample-sin-10
  "Fits sin on [-10, 10) with 10 hidden units and plots the result."
  []
  (plot-sin-fit 10 training-list-sin10 -10 10))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment