Skip to content

Instantly share code, notes, and snippets.

@badbye
Last active April 21, 2016 02:42
Show Gist options
  • Save badbye/d2f8830c4ab7eac80d8516e649e43733 to your computer and use it in GitHub Desktop.
Save badbye/d2f8830c4ab7eac80d8516e649e43733 to your computer and use it in GitHub Desktop.
求调试 (Help wanted: debugging nnet vs. mxnet on iris)
## nnet ----
# Fit a 4-2-3 network (one hidden layer of 2 units) to iris with nnet and
# report its training accuracy, then draw the fitted network.
library(nnet)

# nnet() draws its starting weights at random (on [-rang, rang]) and no seed
# was set, so a single fit was not reproducible: the accuracy flipped between
# ~0.66 and ~0.99 from run to run (the 0.66 runs are presumably poor local
# optima). Fix the seed and keep the best of several restarts, judged by the
# final value of the fitting criterion (`$value`, lower is better).
set.seed(0)
n_restarts <- 5L
fits <- lapply(seq_len(n_restarts), function(i) {
  nnet(Species ~ ., data = iris, size = 2, maxit = 500, abstol = 1e-6,
       trace = FALSE)
})
n <- fits[[which.min(vapply(fits, function(fit) fit$value, numeric(1)))]]
mean(predict(n, iris[, 1:4], type = "class") == iris[, 5])
# a single unseeded fit gave 0.66 or 0.99 depending on the random start

# plot.nnet() is defined by the remote script below, pinned to a commit.
# NOTE(review): source_url() executes downloaded code; consider vendoring the
# script or passing `sha1 =` so its content is verified before it runs.
library(devtools)
source_url("https://gist.githubusercontent.com/fawda123/7471137/raw/466c1474d0a505ff044412703516c34f1a4684a5/nnet_plot_update.r")
plot.nnet(n)
## mxnet ----
# Fit the same 4-2-3 network with mxnet and report its training accuracy.
library(mxnet)

set.seed(0)                    # make the row shuffle reproducible
index <- sample(nrow(iris))    # random permutation of all 150 rows
# x_train <- data.matrix(iris[index, 1:4])
x_train <- scale(data.matrix(iris[index, 1:4]))
# SoftmaxOutput expects 0-based class labels (0, ..., n_classes - 1), and
# mx.metric.accuracy compares labels against max.col(t(pred)) - 1. The factor
# codes of Species are 1..3, so shift them down by one; with 1..3, label 3
# has no matching output unit.
y_train <- as.numeric(iris[index, 5]) - 1

d <- mx.symbol.Variable("data")
h1 <- mx.symbol.FullyConnected(data = d, num.hidden = 2)
# 'relu' or 'tanh' did not help here; the activation was not the problem.
h1 <- mx.symbol.Activation(data = h1, act.type = "sigmoid")
h2 <- mx.symbol.FullyConnected(data = h1, num.hidden = 3)
res <- mx.symbol.SoftmaxOutput(h2)

mx.set.seed(0)
model <- mx.model.FeedForward.create(
  res,
  X = x_train, y = y_train,
  ctx = mx.cpu(),
  num.round = 500, array.batch.size = 150,
  # The original learning.rate = 1e-6 was far too small: after 500 full-batch
  # steps the weights barely moved from their initialisation, so every row
  # got the same class and accuracy stayed at 1/3.
  learning.rate = 0.1, momentum = 0.9,
  eval.metric = mx.metric.accuracy,
  array.layout = "rowmajor"    # rows of x_train are samples
)

# predict() returns one column per sample; max.col() yields 1-based class
# indices, so shift back to the 0-based labels before comparing.
pred <- predict(model, x_train, array.layout = "rowmajor")
mean(max.col(t(pred)) - 1 == y_train)
# was always 0.33 with learning.rate = 1e-6 and 1-based labels
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment