Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
# Neural network sort: generates a learning curve for a network that learns to sort 4 numbers.
# See also https://gist.github.com/primaryobjects/3b41f8b2f122eb16a65b
library(neuralnet)
library(ggplot2)
library(reshape2)
# Helper method to generate a training set of `size` rows. Each row holds four
# random inputs (a, b, c, d) drawn from 1..max and, as outputs, the same four
# numbers in ascending order (w <= x <= y <= z).
#
# size: number of rows to generate (0 yields an empty 8-column data frame).
# max:  largest value that may be drawn; values are integers in 1..max.
# Returns a data.frame with columns a, b, c, d, w, x, y, z.
generateSet <- function(size = 100, max = 100) {
  # sample.int(max, ...) draws exactly what sample(1:max, ...) did, without
  # sample()'s special case for length-one vectors.
  draw <- function() sample.int(max, size, replace = TRUE)
  training <- data.frame(a = draw(), b = draw(), c = draw(), d = draw())

  # Sort every row at once: ordering all cells by (row index, value) lists each
  # row's values in ascending order, one row after another. This replaces a
  # per-row loop that grew a data frame with rbind() (quadratic) and called
  # order() on a one-row data frame, which warns ("cannot xtfrm data frames")
  # in recent R versions.
  m <- as.matrix(training)
  sorted <- matrix(m[order(row(m), m)], ncol = ncol(m), byrow = TRUE)
  output <- as.data.frame(sorted)
  names(output) <- c("w", "x", "y", "z")

  # Append output to the training set.
  cbind(training, output)
}
# Helper method to map scaled values back to their original units, undoing
# scale(). `x` is the matrix (or data frame) to restore; `orig` is the object
# returned by scale(), whose "scaled:center" and "scaled:scale" attributes are
# reused. Each row r of x becomes r * scale + center, and the rows are stacked
# back into a matrix.
#
# NOTE: ordinary R recycling applies when ncol(x) differs from the number of
# scaled columns. With a 4-column x and 8 scaled columns, each result row has 8
# entries and only entries 5:8 pair a value with its own scale/center. Callers
# in this script depend on that by selecting [, 5:8].
unscale <- function(x, orig) {
  spread <- attr(orig, "scaled:scale")
  offset <- attr(orig, "scaled:center")
  restoreRow <- function(row) row * spread + offset
  t(apply(x, 1, restoreRow))
}
# Sorts the numbers a, b, c, d with the trained network `fit`.
# scaleVal: the scale()d data set whose "scaled:center"/"scaled:scale"
#   attributes are applied to the inputs and then inverted on the outputs.
# Returns the rounded, unscaled network outputs for w, x, y, z
# (a length-4 vector when a..d are single numbers).
nnsort <- function(fit, scaleVal, a, b, c, d) {
  ctr <- attr(scaleVal, "scaled:center")
  scl <- attr(scaleVal, "scaled:scale")
  # w..z are zero placeholders so the frame has the same 8 columns the scaling
  # attributes cover; only a..d are fed to the network.
  query <- data.frame(a = a, b = b, c = c, d = d, w = 0, x = 0, y = 0, z = 0)
  queryScaled <- as.data.frame(scale(query, ctr, scl))
  prediction <- compute(fit, queryScaled[, 1:4])$net.result
  # unscale() recycles the 4 predictions against all 8 scaled columns, so
  # columns 5:8 hold the outputs in original units.
  restored <- round(unscale(prediction, scaleVal))
  restored[, 5:8]
}
# Learning-curve experiment: train a network on progressively larger data sets
# and plot the fraction of rows it sorts exactly right on the training and
# cross-validation halves.
numIterations <- 30  # number of points on the learning curve
stepSize <- 50       # examples added per iteration (split evenly train/cv)
maxValue <- 50       # inputs are drawn from 1..maxValue

# Number of rows of `set` (scaled with `scaling`) whose four predicted outputs
# all match the true sorted values. Truth is rounded to 10 digits before the
# equality test to absorb floating-point error from unscaling
# (http://stackoverflow.com/a/18668681).
countCorrect <- function(set, predicted, scaling) {
  truth <- round(unscale(set, scaling)[, 5:8, drop = FALSE], 10)
  sum(rowSums(truth == predicted[, 5:8, drop = FALSE]) == 4, na.rm = TRUE)
}

results <- data.frame(Train = numeric(0), CV = numeric(0))
trainSizes <- integer(0)  # training-set size behind each row of `results`

for (i in seq_len(numIterations)) {
  # Generate training and cv data.
  data <- generateSet(i * stepSize, maxValue)
  # Normalize data; scale() keeps the center/scale attributes unscale() needs.
  data <- scale(data)
  # Split for a training and cv set. Integer division keeps every row when the
  # set size is odd (a fractional `half` would silently drop the last row).
  half <- nrow(data) %/% 2
  training <- data[1:half, ]
  cv <- data[(half + 1):nrow(data), ]
  # Train neural network.
  fit <- neuralnet(w + x + y + z ~ a + b + c + d,
                   training,
                   hidden = c(40, 40),
                   threshold = 0.01,
                   rep = 1,
                   learningrate = 0.6,
                   stepmax = 9999999,
                   lifesign = "full")
  # Check results: the network predicts w..z only, so unscale() recycles the
  # predictions against all 8 scaled columns and columns 5:8 hold real values.
  results1 <- round(unscale(compute(fit, training[, 1:4])$net.result, data))
  results2 <- round(unscale(compute(fit, cv[, 1:4])$net.result, data))
  correct1 <- countCorrect(training, results1, data)
  correct2 <- countCorrect(cv, results2, data)
  # Record accuracy history with the training-set size it belongs to.
  results <- rbind(results, data.frame(Train = correct1 / nrow(training),
                                       CV = correct2 / nrow(cv)))
  trainSizes <- c(trainSizes, nrow(training))
  # Plot learning curve. The x axis is the actual training-set size instead of
  # a hard-coded multiple that silently breaks when stepSize changes.
  r <- melt(results, measure.vars = c("Train", "CV"),
            variable.name = "Set", value.name = "Accuracy")
  r$Count <- rep(trainSizes, times = 2)
  print(ggplot(data = r, aes(x = Count, y = Accuracy, colour = Set)) + geom_line())
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.