@nqbao
Created June 25, 2016 17:11
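
# Fit a multinomial logistic regression (nnet::multinom) to iris and
# measure its accuracy on a held-out test set.
library("nnet")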
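# Randomly split a data frame into a training part and a test part.
# ratio: fraction of rows assigned to the training set.
# seed:  optional RNG seed so the split is reproducible.
splitdf <- function(dataframe, ratio = 0.8, seed = NULL) {
  if (!is.null(seed)) set.seed(seed)
  n <- nrow(dataframe)
  trainindex <- sample.int(n, size = floor(ratio * n))
  trainset <- dataframe[trainindex, ]
  testset <- dataframe[-trainindex, ]
  list(train = trainset, test = testset)
}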
# A fixed (arbitrary) seed makes the split, and therefore the accuracy, reproducible.
split <- splitdf(iris, seed = 42)
# Model Species as a function of all four flower measurements.
m <- multinom(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width,
              data = split$train)
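# Accuracy = correctly classified / total: the diagonal of the confusion
# matrix divided by its total count.
calculate_accuracy <- function(predicted, actual) {
  d <- table(predicted, actual)
  sum(diag(d)) / sum(d)
}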
# Predict classes for the held-out rows and score them.
predicted <- predict(m, newdata = split$test)
calculate_accuracy(predicted, split$test$Species)
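
As a cross-check, the sketch below repeats the same pipeline end to end with an arbitrary seed (42) and an inline 80/20 split instead of `splitdf`. It prints the confusion matrix that `calculate_accuracy` reduces to a single number, confirms that diagonal-over-total equals the plain fraction of matching labels, and shows the per-class probabilities that `predict` returns with `type = "probs"`. The variable names are illustrative, and `trace = FALSE` only silences the optimizer's iteration log.

```r
library(nnet)

set.seed(42)  # arbitrary seed, only for a reproducible split
n <- nrow(iris)
train_idx <- sample.int(n, size = floor(0.8 * n))
train <- iris[train_idx, ]
test <- iris[-train_idx, ]

# Species ~ . uses all four measurements; trace = FALSE hides the fit log
fit <- multinom(Species ~ ., data = train, trace = FALSE)
predicted <- predict(fit, newdata = test)  # type = "class" is the default

# Confusion matrix: rows are predictions, columns are true labels
confusion <- table(predicted = predicted, actual = test$Species)
print(confusion)

# Diagonal over total equals the plain fraction of matching labels
acc_diag <- sum(diag(confusion)) / sum(confusion)
acc_mean <- mean(predicted == test$Species)
stopifnot(isTRUE(all.equal(acc_diag, acc_mean)))
print(acc_diag)

# Per-class probabilities: one row per test flower, one column per species
probs <- predict(fit, newdata = test, type = "probs")
print(head(round(probs, 3)))
stopifnot(isTRUE(all.equal(unname(rowSums(probs)), rep(1, nrow(test)))))
```

The equality check relies on `predicted` and `test$Species` sharing the same factor levels in the same order. That holds here because `predict` on a `multinom` fit returns a factor built from the response's levels.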