# Create a gist now

# Instantly share code, notes, and snippets.

# What would you like to do?
# **********************************
# Author: Benjamin Tovar
# Date: March 25, 2015
# Post: Using Neural Networks to fit equations in R
# Post url: http://btovar.com/2015/03/neural-networks-fit-eq/
# **********************************
# load libraries
# RSNNS supplies mlp() and plotIterativeError(); ggplot2 draws the plots.
# (The RSNNS call was previously missing its closing parenthesis, which made
# the whole script fail to parse.)
library(RSNNS)
library(ggplot2)
# Target curve the neural network will try to learn: a sine wave whose
# amplitude decays exponentially, f(x) = exp(-x / 10) * sin(x).
# Vectorised: a numeric vector `x` yields a numeric vector of the same length.
f <- function(x) {
  decay <- exp(-x / 10)
  decay * sin(x)
}
# Sample the input on a regular grid over [0, 20] with step 0.1 (201 points),
# evaluate the target curve at each point, and keep both columns together.
x <- seq(0, 20, by = 0.1)
y <- f(x)
dataset <- data.frame(x = x, y = y)
# **************
# plot the function
# **************
# The ggplot object is built first and then print()ed explicitly: a ggplot is
# only drawn when printed, and top-level values are not auto-printed when the
# script is run with source() (without echo = TRUE / print.eval = TRUE).
target_plot <- ggplot(dataset) +
  aes(x = x, y = y) +
  geom_point(colour = "#ff6f69") +
  geom_line(colour = "#ff6f69") +
  ylim(c(-1, 1)) +
  xlim(c(0, 20)) +
  labs(title = "Equation to fit: f(x) = exp(-x/10)*sin(x)")
print(target_plot)
# ********************************************
# train the NN model
# use the values of x to estimate values of y
# ********************************************
# Multilayer perceptron mapping x -> y:
#   - one hidden layer of 15 units with tanh activation,
#   - linear output unit (linOut = TRUE), so outputs are not confined to (0, 1),
#   - at most 10000 training iterations.
# learnFuncParams is handed to the learning function; with the RSNNS default
# (standard backpropagation) the first value, 0.01, is the learning rate.
# Weights are randomly initialised, so the fit can differ between runs.
model <- mlp(
  dataset$x,
  dataset$y,
  size = 15,
  maxit = 10000,
  linOut = TRUE,
  hiddenActFunc = "Act_TanH",
  learnFuncParams = c(0.01, 0)
)
# Uncomment any of these to inspect the fitted network:
# summary(model)
# model
# weightMatrix(model)
# extractNetInfo(model)
# *****************
# plot the training error of the fitted model across iterations
# *****************
plotIterativeError(
  model,
  main = "Training error",
  col = "#ff6f69",
  ylim = c(0, 25)
)
# overlay reference grid lines on the base-graphics plot just drawn
grid()
# ***********************
# use the trained model
# to predict values of y
# ***********************
# Feed the training inputs back through the network, passed as a one-column
# matrix, to obtain its estimate of y at every x.
y_pred <- predict(model, as.matrix(dataset$x))
# how many predictions were produced
l <- length(y_pred)
# ****************************************
# set the format of the dataset for ggplot
# *****************************************
# add column of predicted y to the dataset
dataset <- cbind(dataset, y_pred)
# ggplot wants "long" data: one row per (x, series) pair, with the series name
# in `variable` and its value in `value`, so colour/shape can map to `variable`.
# Two equivalent ways to build it are shown; option 2 overwrites option 1.
# option 1: constructing the data.frame by hand
dataset_2 <- data.frame(
  x = c(dataset$x, dataset$x),
  value = c(dataset$y, y_pred),
  variable = c(rep("y", l), rep("y_pred", l))
)
# option 2: using reshape library and melt function
library(reshape)
# `id.vars` is spelled out; the previous `id =` only worked through partial
# argument matching against melt.data.frame's `id.vars` formal.
dataset_2 <- melt(dataset, id.vars = "x")
# ****************************************
# plot the results (y and predicted y)
# *****************************************
# Printed explicitly so the plot is also drawn when the script is run via
# source() (top-level values are not auto-printed there).
# Note: ylim() sets scale limits, so any point (e.g. a prediction) outside
# [-1, 1] is removed from the plot with a warning rather than clipped.
fit_plot <- ggplot(dataset_2) +
  aes(x = x, y = value) +
  geom_point(
    aes(colour = variable, fill = variable, shape = variable),
    alpha = 0.6
  ) +
  geom_line(aes(colour = variable), alpha = 0.4) +
  ylim(c(-1, 1)) +
  xlim(c(0, 20)) +
  labs(title = "Use of package RSNNS in R to fit an equation", y = "y")
print(fit_plot)

# Reader comment: Is line 9 (the library(RSNNS) call) missing a ")"?

# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.