@mhagita
Last active March 20, 2017 09:07
Predicting future orders with R and a neural network ref: http://qiita.com/mhagita/items/6afb7e2edd038b32d37d
r1 r2 r3
1 2 3
2 3 1
3 1 2
1 3 2
2 1 3
3 2 1
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
4 2 1
3 1 2
1 2 3
2 3 1
3 1 2
2 1 3
1 3 2
2 1 3
3 2 1
1 2 3
2 3 1
3 1 2
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 3 2
2 1 3
3 2 1
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
1 2 3
2 3 1
3 1 2
1 3 2
2 1 3
3 2 1
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
4 2 1
3 1 2
1 2 3
2 3 1
3 1 2
2 1 3
1 3 2
2 1 3
3 2 1
1 2 3
2 3 1
3 1 2
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 3 2
2 1 3
3 2 1
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
1 2 3
2 3 1
3 1 2
1 3 2
2 1 3
3 2 1
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
4 2 1
3 1 2
1 2 3
2 3 1
3 1 2
2 1 3
1 3 2
2 1 3
3 2 1
1 2 3
2 3 1
3 1 2
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 3 2
2 1 3
3 2 1
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 3 2
2 1 3
3 2 1
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 3 2
2 1 3
3 2 1
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 3 2
2 1 3
3 2 1
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 3 2
2 1 3
3 2 1
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 3 2
2 1 3
3 2 1
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 3 2
2 1 3
3 2 1
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 3 2
2 1 3
3 2 1
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 3 2
2 1 3
3 2 1
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 3 2
2 1 3
3 2 1
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 3 2
2 1 3
3 2 1
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 3 2
2 1 3
3 2 1
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 3 2
2 1 3
3 2 1
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 3 2
2 1 3
3 2 1
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 3 2
2 1 3
3 2 1
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 3 2
2 1 3
3 2 1
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 3 2
2 1 3
3 2 1
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 3 2
2 1 3
3 2 1
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
1 2 3
2 3 1
3 1 2
r1,r2,r3
1,2,3
2,3,1
3,1,2
1,3,2
2,1,3
...
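Every sample row of data.csv shown above is a permutation of 1, 2 and 3. That means r3 is fully determined by r1 and r2 (r3 = 6 - r1 - r2), which is why a small network can learn it perfectly. The base-R sketch below checks this on the five rows shown. The inline `text =` copy stands in for the full file, which is elided here. Note that the whitespace-separated listing also contains a few `4 2 1` rows, and those do not follow this pattern.

```r
# The five sample rows of data.csv shown above (the rest of the file is elided)
rows <- read.csv(text = "r1,r2,r3
1,2,3
2,3,1
3,1,2
1,3,2
2,1,3", header = TRUE)

# In a permutation of 1, 2, 3 the values sum to 6,
# so the third value can be recovered from the first two
derived_r3 <- 6 - rows$r1 - rows$r2
print(all(derived_r3 == rows$r3))  # TRUE
```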
# Load nnet
> library('nnet')
# Read the data
> df <- read.csv("data.csv", header = TRUE, colClasses=c("factor", "factor", "factor"))
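Because `colClasses` reads all three columns as factors, the formula `r3 ~ .` expands each predictor into treatment-contrast dummy columns. Level "1" is the baseline, and levels "2" and "3" each get a 0/1 column. This is why the fitted network reports four inputs, named `r12`, `r13`, `r22` and `r23` in `coefnames`. The following is a minimal sketch of that encoding using base R's `model.matrix()`, run on the five sample rows rather than the full file:

```r
# Same five sample rows, read as factors as in the session above
df_sample <- read.csv(text = "r1,r2,r3
1,2,3
2,3,1
3,1,2
1,3,2
2,1,3", header = TRUE, colClasses = c("factor", "factor", "factor"))

# Default treatment contrasts: one 0/1 column per non-baseline level
X <- model.matrix(r3 ~ ., data = df_sample)

# Drop the intercept column; the remaining columns are the network inputs
inputs <- X[, -1, drop = FALSE]
print(colnames(inputs))  # "r12" "r13" "r22" "r23"
print(inputs[1, ])       # row r1 = 1, r2 = 2 encodes as 0 0 1 0
```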
# Training data (indices 1 to 300 in steps of 2, i.e. the odd rows)
> train <- seq.int(1, 300, by=2)
# Test data (indices in 1 to 300 that are not in train, i.e. the even rows)
> test <- setdiff(1:300, train)
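The two index vectors split rows 1-300 into alternating halves: odd rows go to training and even rows go to testing. Each half has 150 rows and the halves do not overlap. The 150-row training set agrees with the `fitted.values: num [1:150, 1:3]` entry in the `str()` output. Here is a quick check of the split:

```r
# Same split as in the session: odd indices for training, even indices for testing
train <- seq.int(1, 300, by = 2)
test <- setdiff(1:300, train)

print(length(train))  # 150
print(length(test))   # 150
print(head(test))     # 2 4 6 8 10 12
```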
# Build the network
> df.nn <- nnet(r3~., df[train,], size=3, decay=0.001)
## weights: 27
#initial value 165.407234
#iter 10 value 3.747046
#iter 20 value 1.153351
#iter 30 value 0.817647
#iter 40 value 0.724959
#iter 50 value 0.645631
#iter 60 value 0.607413
#iter 70 value 0.561134
#iter 80 value 0.529012
#iter 90 value 0.478941
#iter 100 value 0.461648
#final value 0.461648
#stopped after 100 iterations
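The `## weights: 27` count follows from the layer sizes. Each of the 3 hidden units (`size=3`) receives the 4 dummy inputs plus a bias. Each of the 3 output units (one per level of r3) receives the 3 hidden units plus a bias. That gives 5 × 3 + 4 × 3 = 27, the same figure reported as "a 4-3-3 network with 27 weights". The arithmetic in R:

```r
# Layer sizes of the fitted network: 4 inputs, size = 3 hidden units, 3 output classes
n_in <- 4
n_hidden <- 3
n_out <- 3

# Each hidden and output unit has one weight per incoming unit plus a bias
n_weights <- (n_in + 1) * n_hidden + (n_hidden + 1) * n_out
print(n_weights)  # 27
```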
# Show the network
> summary(df.nn)
#a 4-3-3 network with 27 weights
#options were - softmax modelling decay=0.001
# b->h1 i1->h1 i2->h1 i3->h1 i4->h1
# 1.79 -4.61 0.69 -4.73 0.32
# b->h2 i1->h2 i2->h2 i3->h2 i4->h2
# 1.98 0.20 -4.58 0.85 -4.21
# b->h3 i1->h3 i2->h3 i3->h3 i4->h3
# -4.32 2.62 3.34 2.48 3.51
# b->o1 h1->o1 h2->o1 h3->o1
# 1.31 -5.92 -5.02 6.47
# b->o2 h1->o2 h2->o2 h3->o2
# -0.88 8.26 -2.43 -2.45
# b->o3 h1->o3 h2->o3 h3->o3
# -0.45 -2.33 7.47 -4.00
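The weights printed by `summary()` are enough to reproduce a prediction by hand. This sketch assumes nnet's usual setup for this fit: logistic hidden units and a softmax output layer, since the summary reports `softmax modelling`. Inputs i1 to i4 are the dummy columns listed in `coefnames` (r12, r13, r22, r23), and outputs o1 to o3 are the levels "1", "2", "3" of r3. The helpers `encode`, `forward_pass` and `predict_class` are hypothetical names for illustration. The printed weights are rounded to two decimals, so the probabilities are only approximate. The argmax is still clear-cut, and it gives the same classes that `predict()` returns for 1→2 and 2→3.

```r
# Weights copied (rounded) from summary(df.nn); each row is a target unit,
# columns are the bias followed by the incoming units in order
W_hidden <- rbind(
  h1 = c( 1.79, -4.61,  0.69, -4.73,  0.32),
  h2 = c( 1.98,  0.20, -4.58,  0.85, -4.21),
  h3 = c(-4.32,  2.62,  3.34,  2.48,  3.51)
)
W_out <- rbind(
  o1 = c( 1.31, -5.92, -5.02,  6.47),
  o2 = c(-0.88,  8.26, -2.43, -2.45),
  o3 = c(-0.45, -2.33,  7.47, -4.00)
)

# Hypothetical helper: encode (r1, r2) as the dummy inputs r12, r13, r22, r23
encode <- function(r1, r2) {
  c(r1 == 2, r1 == 3, r2 == 2, r2 == 3) * 1
}

# Hypothetical helper: one forward pass through the 4-3-3 network
forward_pass <- function(x) {
  hidden <- 1 / (1 + exp(-(W_hidden %*% c(1, x))))  # logistic hidden units
  logits <- W_out %*% c(1, hidden)                   # linear output units
  probs <- exp(logits) / sum(exp(logits))            # softmax over the 3 classes
  as.vector(probs)
}

# Hypothetical helper: pick the most probable level of r3 ("1", "2" or "3")
predict_class <- function(r1, r2) {
  as.character(which.max(forward_pass(encode(r1, r2))))
}

print(predict_class(1, 2))  # "3"
print(predict_class(2, 3))  # "1"
```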
> str(df.nn)
#List of 20
# $ n : num [1:3] 4 3 3
# $ nunits : int 11
# $ nconn : num [1:12] 0 0 0 0 0 0 5 10 15 19 ...
# $ conn : num [1:27] 0 1 2 3 4 0 1 2 3 4 ...
# $ nsunits : num 8
# $ decay : num 0.001
# $ entropy : logi FALSE
# $ softmax : logi TRUE
# $ censored : logi FALSE
# $ value : num 0.462
# $ wts : num [1:27] 1.787 -4.609 0.694 -4.732 0.319 ...
# $ convergence : int 1
# $ fitted.values: num [1:150, 1:3] 0.000154 0.000181 0.000322 0.000154 0.000181 ...
# ..- attr(*, "dimnames")=List of 2
# .. ..$ : chr [1:150] "1" "3" "5" "7" ...
# .. ..$ : chr [1:3] "1" "2" "3"
# $ residuals : num [1:150, 1:3] -0.000154 -0.000181 -0.000322 -0.000154 -0.000181 ...
# ..- attr(*, "dimnames")=List of 2
# .. ..$ : chr [1:150] "1" "3" "5" "7" ...
# .. ..$ : chr [1:3] "1" "2" "3"
# $ lev : chr [1:3] "1" "2" "3"
# $ call : language nnet.formula(formula = r3 ~ ., data = df[train, ], size = 3, decay = 0.001)
# $ terms :Classes 'terms', 'formula' language r3 ~ r1 + r2
# .. ..- attr(*, "variables")= language list(r3, r1, r2)
# .. ..- attr(*, "factors")= int [1:3, 1:2] 0 1 0 0 0 1
# .. .. ..- attr(*, "dimnames")=List of 2
# .. .. .. ..$ : chr [1:3] "r3" "r1" "r2"
# .. .. .. ..$ : chr [1:2] "r1" "r2"
# .. ..- attr(*, "term.labels")= chr [1:2] "r1" "r2"
# .. ..- attr(*, "order")= int [1:2] 1 1
# .. ..- attr(*, "intercept")= int 1
# .. ..- attr(*, "response")= int 1
# .. ..- attr(*, ".Environment")=<environment: R_GlobalEnv>
# .. ..- attr(*, "predvars")= language list(r3, r1, r2)
# .. ..- attr(*, "dataClasses")= Named chr [1:3] "factor" "factor" "factor"
# .. .. ..- attr(*, "names")= chr [1:3] "r3" "r1" "r2"
# $ coefnames : chr [1:4] "r12" "r13" "r22" "r23"
# $ contrasts :List of 2
# ..$ r1: chr "contr.treatment"
# ..$ r2: chr "contr.treatment"
# $ xlevels :List of 2
# ..$ r1: chr [1:3] "1" "2" "3"
# ..$ r2: chr [1:3] "1" "2" "3"
# - attr(*, "class")= chr [1:2] "nnet.formula" "nnet"
# Run the prediction
> df.pred <- predict(df.nn, df[test,], type="class")
# Check agreement (compare the true answers in the test data with the model's predictions)
> table(df[test,3], df.pred)
# df.pred
# 1 2 3
# 1 49 0 0
# 2 0 50 0
# 3 0 0 51
#(all predictions correct)
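The table is a confusion matrix. Rows are the true r3 values of the 150 test rows, columns are the predicted classes, and any count off the diagonal would be a misclassification. Overall accuracy is the diagonal sum divided by the total, here 150/150 = 1. The sketch below recomputes this from the printed counts. In a live session, the same `sum(diag(...)) / sum(...)` expression can be applied directly to the result of `table()`.

```r
# Confusion matrix copied from the table() output above (rows = true r3, columns = predicted)
conf_mat <- matrix(c(49,  0,  0,
                      0, 50,  0,
                      0,  0, 51),
                   nrow = 3, byrow = TRUE,
                   dimnames = list(actual = c("1", "2", "3"),
                                   predicted = c("1", "2", "3")))

# Accuracy = correctly classified (diagonal) / all test rows
accuracy <- sum(diag(conf_mat)) / sum(conf_mat)
print(sum(conf_mat))  # 150
print(accuracy)       # 1
```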
# New predictions (for the sequences 1→2→X1 and 2→3→X2, what is each Xn?)
> df.test2 <- data.frame(r1=c("1","2"), r2=c("2","3"))
> predict(df.nn, df.test2, type="class")
#[1] "3" "1"
#(correct)
> source("http://hosho.ees.hokudai.ac.jp/~kubo/log/2007/img07/plot.nn.txt")
> plot.nn(df.nn)