Created October 31, 2016 11:33
Save masatoi/98c53bd13b0c2ed96ec592b9d35f1f22 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
;;; -*- coding:utf-8; mode:lisp; -*-
;;; Example: train a small MGL feed-forward network on a two-feature,
;;; two-class (labels 0/1) dataset read from a CSV file, then plot the data
;;; and the network's predictions.
(ql:quickload :mgl-user)
(in-package :mgl-user)
;;; Load the data.
;;; Each CSV row is read as a list of numbers; DATASET below uses columns 0-1
;;; as the two features and column 2 as the class label.
(ql:quickload :fare-csv)
(ql:quickload :parse-number)

(defparameter *data-path* "/home/wiz/tmp/mgl-logistic-data.txt"
  "Path of the CSV file holding the training data.")

(defparameter data-list
  (mapcar (lambda (row)
            ;; FARE-CSV returns every field as a string; convert each to a number.
            (mapcar #'parse-number:parse-number row))
          (fare-csv:read-csv-file *data-path*))
  "List of CSV rows, each a list of numbers parsed from one line of *DATA-PATH*.")
(defparameter dataset
  (let ((datums (make-array (length data-list))))
    ;; LOOP terminates when DATA-LIST is exhausted, so the index needs no
    ;; explicit upper bound.
    (loop for i from 0
          for (x1 x2 label) in data-list
          do (setf (aref datums i)
                   (make-datum :id i
                               :label label
                               :array (make-mat 2 :initial-contents (list x1 x2)))))
    datums)
  "Vector of datums, one per row of DATA-LIST. The row index is the datum id,
the first two columns form a 2-element MAT, and the third column is the label.")
;;; Normalize the data.
;;; The "!" suffix suggests DATASET-NORMALIZE! modifies its argument in place,
;;; which is presumably why it is applied to a copy and the raw DATASET is kept
;;; (TODO confirm whether COPY-DATASET also copies each datum's array).
(defparameter dataset-normal (copy-dataset dataset))
(dataset-normalize! dataset-normal)
;;; Visualize the data.
(ql:quickload :clgplot)

(defun plot-dataset (dataset)
  "Scatter-plot the datums in DATASET (a vector) with CLGP:PLOT-LISTS.
Datums labelled 1 and datums labelled 0 are drawn as two separate series.
Feature 0 of each datum is the x coordinate and feature 1 is the y coordinate."
  (flet ((data-with-label (label)
           ;; Datums whose label is numerically equal to LABEL.
           (remove-if-not (lambda (datum) (= (datum-label datum) label)) dataset))
         (feature-column (data index)
           ;; Feature INDEX of every datum in DATA, collected into a list.
           (loop for datum across data
                 collect (mref (datum-array datum) index))))
    (let ((positive-data (data-with-label 1))
          (negative-data (data-with-label 0)))
      (clgp:plot-lists
       (list (feature-column positive-data 1)
             (feature-column negative-data 1))
       :x-lists (list (feature-column positive-data 0)
                      (feature-column negative-data 0))
       :style 'points))))
;;; Model definition.
;;; INPUTS (2 units) -> F1, an element-wise sigmoid with no learned parameters
;;; of its own -> an ->ACTIVATION named PREDICTION (2 units, with bias and
;;; F1->PREDICTION weights) -> softmax cross-entropy loss over the two classes.
;;; NOTE(review): because F1 squashes the raw inputs before any weights are
;;; applied, this is not plain logistic regression; drop F1 and feed INPUTS to
;;; the activation if that is what is intended.
;;; :MAX-N-STRIPES 100 is presumably the maximum mini-batch size — confirm
;;; against the MGL docs.
(defparameter fnn-sigmoid
  (build-fnn (:class 'fnn :max-n-stripes 100)
    (inputs (->input :size 2))
    (f1 (->sigmoid inputs))
    (prediction (->softmax-xe-loss (->activation f1 :name 'prediction :size 2)
                                   :name 'prediction))))
;;; Train the model for 100 epochs.
;;; NOTE(review): DATASET-NORMAL is passed as both the training set and
;;; (presumably) the test set, so the "test" figures in the log below are
;;; really training figures. Hold out part of the data for an honest estimate.
(train-fnn-process-with-monitor fnn-sigmoid dataset-normal dataset-normal
                                :n-epochs 100)
;; 2016-10-31 18:35:40: ---------------------------------------------------
;; 2016-10-31 18:35:40: training at n-instances: 10000
;; 2016-10-31 18:35:40: test at n-instances: 10000
;; 2016-10-31 18:35:40: pred. train bpn PREDICTION acc.: 91.00% (10000)
;; 2016-10-31 18:35:40: pred. train bpn PREDICTION xent: 3.456e-3 (10000)
;; 2016-10-31 18:35:40: pred. test bpn PREDICTION acc.: 91.00% (100)
;; 2016-10-31 18:35:40: pred. test bpn PREDICTION xent: 3.456e-3 (100)
;; 2016-10-31 18:35:40: Foreign memory usage:
;; foreign arrays: 0 (used bytes: 0)
;; CUDA memory usage:
;; device arrays: 114 (used bytes: 400,112, pooled bytes: 0)
;; host arrays: 0 (used bytes: 0)
;; host->device copies: 202, device->host copies: 20,602
;; 2016-10-31 18:35:40: ---------------------------------------------------
;;; Inspect the learned weights.
;;; Clump 2 of FNN-SIGMOID is the ->ACTIVATION named PREDICTION. Inside it,
;;; clump 0 is the bias (:BIAS PREDICTION) and clump 2 is the (F1 PREDICTION)
;;; weight matrix, as the DESCRIBE output below shows. These positional
;;; indices depend on the order in which BUILD-FNN / ->ACTIVATION create their
;;; clumps; re-check them if the model definition changes.
(let* ((prediction-activation (aref (clumps fnn-sigmoid) 2))
       (activation-clumps (clumps prediction-activation))
       (bias (aref activation-clumps 0))
       (weight (aref activation-clumps 2)))
  (describe bias)
  (describe weight)
  (list bias weight))
;; #<->WEIGHT (:BIAS PREDICTION) :SIZE 2 1/1 :NORM 2.76652>
;; [standard-object]
;; Slots with :INSTANCE allocation:
;; NAME = (:BIAS PREDICTION)
;; SIZE = 2
;; NODES = #<MAT 1x2 AF #2A((1.956225 -1.9562262))>
;; DERIVATIVES = #<MAT 1x2 A #2A((0.0 0.0))>
;; DEFAULT-VALUE = 0
;; SHARED-WITH-CLUMP = NIL
;; DIMENSIONS = (1 2)
;; #<->WEIGHT (F1 PREDICTION) :SIZE 4 1/1 :NORM 4.55551>
;; [standard-object]
;; Slots with :INSTANCE allocation:
;; NAME = (F1 PREDICTION)
;; SIZE = 4
;; NODES = #<MAT 2x2 AF #2A((-2.4052894 2.4102905) (-2.1375985 2.1420684))>
;; DERIVATIVES = #<MAT 2x2 A #2A((0.0 0.0) (0.0 0.0))>
;; DEFAULT-VALUE = 0
;; SHARED-WITH-CLUMP = NIL
;; DIMENSIONS = (2 2)
;; (#<->WEIGHT (:BIAS PREDICTION) :SIZE 2 1/1 :NORM 2.76652>
;; #<->WEIGHT (F1 PREDICTION) :SIZE 4 1/1 :NORM 4.55551>)
(defun plot-prediction (dataset fnn class &key (grid-step 0.1))
  "Plot FNN's output for CLASS over a regular grid with CLGP:SPLOT-LIST.

DATASET is a vector of datums with two features. Its per-feature minimum and
maximum define the grid's extent. FNN is the network passed to PREDICT-DATUM.
CLASS is the index into the predicted MAT that is plotted. GRID-STEP (default
0.1) is the grid spacing along both axes and must be positive.

Signals an error if DATASET is empty, because the plotting range would be
undefined."
  (check-type grid-step (real (0)))
  (when (zerop (length dataset))
    (error "PLOT-PREDICTION: DATASET is empty."))
  (flet ((feature-bounds (index)
           ;; Minimum and maximum of feature INDEX over DATASET, in one pass.
           (loop for datum across dataset
                 for value = (mref (datum-array datum) index)
                 minimize value into lo
                 maximize value into hi
                 finally (return (values lo hi))))
         (axis-points (lo hi)
           ;; Grid coordinates from LO up to HI, spaced GRID-STEP apart.
           (loop for x from lo to hi by grid-step collect x)))
    (multiple-value-bind (min-x1 max-x1) (feature-bounds 0)
      (multiple-value-bind (min-x2 max-x2) (feature-bounds 1)
        (clgp:splot-list
         (lambda (x1 x2)
           ;; Wrap the grid point in a datum so PREDICT-DATUM can evaluate it.
           ;; The id and label are placeholders, presumably ignored by
           ;; PREDICT-DATUM — TODO confirm.
           (mref (predict-datum fnn
                                (make-datum :id 0 :label 0d0
                                            :array (make-mat 2 :initial-contents (list x1 x2))))
                 class))
         (axis-points min-x1 max-x1)
         (axis-points min-x2 max-x2)
         :map t)))))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.