## Loading required packages
library("readr")
library("dplyr")
library("mxnet")
library("abind")
mx.set.seed(1234)
## Preprocessing steps
## generating synthetic data
len <- 44000
df <- data.frame(dim1 = numeric(len), dim2 = numeric(len), dim3 = numeric(len),
                 dim4 = numeric(len), dim5 = numeric(len), dim6 = numeric(len),
                 dim7 = numeric(len))
df$dim1 <- sin(pi/12 * (1:len))
df$dim2 <- sin(5 + pi/6 * (1:len))
df$dim3 <- sin(15 + pi/8 * (1:len))
df$dim4 <- sin(55 + pi/10 * (1:len))
df$dim5 <- sin(2 + pi/16 * (1:len))
df$dim6 <- sin(1 + pi/7 * (1:len))
df$dim7 <- 7*df$dim1 + 5*df$dim2 + 17*df$dim3 + 25*df$dim4 + 19*df$dim5 + 21*df$dim6
## Normalise each feature to the range (0, 1)
df <- matrix(as.matrix(df), ncol = ncol(df), dimnames = NULL)
rangenorm <- function(x) {
  (x - min(x)) / (max(x) - min(x))
}
df <- apply(df, 2, rangenorm)
df <- t(df)
n_dim <- 7
seq_len <- 100
num_samples <- 430
# extract only required data from dataset
trX <- df[1:n_dim, 1:(seq_len * num_samples)]
# the label data (next output value) is one time step ahead of the current output value
trY <- df[7, 2:(1+(seq_len * num_samples))]
## reshape the matrices into the format expected by MXNet-R RNNs
trainX <- trX
dim(trainX) <- c(n_dim, seq_len, num_samples)
trainY <- trY
dim(trainY) <- c(seq_len, num_samples)
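## Optional sanity check (not part of the original gist): the one-to-one RNN
## expects data shaped features x seq_len x samples and labels shaped seq_len x samples.
stopifnot(all(dim(trainX) == c(n_dim, seq_len, num_samples)))
stopifnot(all(dim(trainY) == c(seq_len, num_samples)))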
batch.size <- 32
# use the first 300 samples for training and the next 100 for evaluation
train_ids <- 1:300
eval_ids <- 301:400
## create the data iterators
train.data <- mx.io.arrayiter(data = trainX[, , train_ids, drop = FALSE],
                              label = trainY[, train_ids],
                              batch.size = batch.size, shuffle = TRUE)
eval.data <- mx.io.arrayiter(data = trainX[, , eval_ids, drop = FALSE],
                             label = trainY[, eval_ids],
                             batch.size = batch.size, shuffle = FALSE)
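## Optional: inspect a single batch to verify the iterator output
## (assumes the standard MXNet-R iterator methods reset(), iter.next() and value()).
train.data$reset()
train.data$iter.next()
batch <- train.data$value()
dim(batch$data)   # expected: n_dim x seq_len x batch.size
dim(batch$label)  # expected: seq_len x batch.size
train.data$reset()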
## Create the symbol for RNN
symbol <- rnn.graph(num_rnn_layer = 2,
                    num_hidden = 50,
                    input_size = NULL,
                    num_embed = NULL,
                    num_decode = 1,
                    masking = FALSE,
                    loss_output = "linear",
                    dropout = 0.5,
                    ignore_label = -1,
                    cell_type = "lstm",
                    output_last_state = TRUE,
                    config = "one-to-one")
mx.metric.mse.seq <- mx.metric.custom("MSE", function(label, pred) {
  label <- mx.nd.reshape(label, shape = -1)
  pred <- mx.nd.reshape(pred, shape = -1)
  res <- mx.nd.mean(mx.nd.square(label - pred))
  return(as.array(res))
})
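## Illustration only (toy arrays, not part of the original gist): the custom
## metric above is simply the mean squared error over all flattened time steps.
toy_label <- mx.nd.array(matrix(c(0.1, 0.2, 0.3, 0.4), nrow = 2))
toy_pred  <- mx.nd.array(matrix(c(0.0, 0.25, 0.35, 0.5), nrow = 2))
as.array(mx.nd.mean(mx.nd.square(toy_label - toy_pred)))   # NDArray computation
mean((as.array(toy_label) - as.array(toy_pred))^2)         # base-R equivalent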
ctx <- mx.cpu()
initializer <- mx.init.Xavier(rnd_type = "gaussian",
                              factor_type = "avg",
                              magnitude = 1)
optimizer <- mx.opt.create("adadelta",
                           rho = 0.9,
                           eps = 1e-06,
                           wd = 1e-06,
                           clip_gradient = 1,
                           rescale.grad = 1/batch.size)
logger <- mx.metric.logger()
epoch.end.callback <- mx.callback.log.train.metric(period = 10,
                                                   logger = logger)
## train the network
system.time(model <- mx.model.buckets(symbol = symbol,
                                      train.data = train.data,
                                      eval.data = eval.data,
                                      num.round = 200,
                                      ctx = ctx,
                                      verbose = TRUE,
                                      metric = mx.metric.mse.seq,
                                      initializer = initializer,
                                      optimizer = optimizer,
                                      batch.end.callback = NULL,
                                      epoch.end.callback = epoch.end.callback))
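## Optional: logger$train and logger$eval accumulate the MSE values logged by the
## callback (the plotting code further below assumes one value per epoch), so a
## quick convergence check after training is:
tail(logger$train, 5)
tail(logger$eval, 5)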
## Extract the output and state symbols from the trained RNN
internals <- model$symbol$get.internals()
sym_state <- internals$get.output(which(internals$outputs %in% "RNN_state"))
sym_state_cell <- internals$get.output(which(internals$outputs %in% "RNN_state_cell"))
sym_output <- internals$get.output(which(internals$outputs %in% "loss_output"))
symbol <- mx.symbol.Group(sym_output, sym_state, sym_state_cell)
## Predict 100 time steps for the 401st sample (the first test sample)
pred_length <- 100
predicted <- numeric()
## Pass the 400th sample through the network to obtain the RNN states,
## which are then used when predicting the next 100 time steps.
data <- mx.nd.array(trainX[, , 400, drop = FALSE])
label <- mx.nd.array(trainY[, 400, drop = FALSE])
## Create a data iterator for the input. Note that a label is required to
## construct the iterator but is not used during inference, so dummy values work too.
infer.data <- mx.io.arrayiter(data = data,
                              label = label,
                              batch.size = 1,
                              shuffle = FALSE)
infer <- mx.infer.rnn.one(infer.data = infer.data,
                          symbol = symbol,
                          arg.params = model$arg.params,
                          aux.params = model$aux.params,
                          input.params = NULL,
                          ctx = ctx)
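## infer is a list: infer[[1]] holds the decoded outputs, infer[[2]] the RNN
## hidden state and infer[[3]] the LSTM cell state; the states are what seed
## the prediction loops below.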
## With the RNN states from above, predict the next 100 steps of this series,
## which is our 401st sample. Further below we show how to modify this code
## for auto-regressive inference.
actual <- trainY[, 401]
## Iterate over the time steps one by one, generating the next value of dim7 at each step
for (i in 1:pred_length) {
  data <- mx.nd.array(trainX[, i, 401, drop = FALSE])
  label <- mx.nd.array(trainY[i, 401, drop = FALSE])
  infer.data <- mx.io.arrayiter(data = data,
                                label = label,
                                batch.size = 1,
                                shuffle = FALSE)
  ## note that we use the RNN state values from the previous iteration here
  infer <- mx.infer.rnn.one(infer.data = infer.data,
                            symbol = symbol,
                            ctx = ctx,
                            arg.params = model$arg.params,
                            aux.params = model$aux.params,
                            input.params = list(rnn.state = infer[[2]],
                                                rnn.state.cell = infer[[3]]))
  pred <- infer[[1]]
  predicted <- c(predicted, as.numeric(as.array(pred)))
}
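## Optional check (not in the original gist): error of the teacher-forced
## one-step-ahead predictions against the ground truth for the 401st sample.
rmse_one_step <- sqrt(mean((predicted - actual)^2))
rmse_one_step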
## Auto-regressive inference
## Again predict 100 time steps for the 401st sample (the first test sample)
pred_length <- 100
predicted <- numeric()
## Pass the 400th sample through the network to obtain the RNN states,
## which are then used when predicting the next 100 time steps.
data <- mx.nd.array(trainX[, , 400, drop = FALSE])
label <- mx.nd.array(trainY[, 400, drop = FALSE])
infer.data <- mx.io.arrayiter(data = data,
                              label = label,
                              batch.size = 1,
                              shuffle = FALSE)
infer <- mx.infer.rnn.one(infer.data = infer.data,
                          symbol = symbol,
                          arg.params = model$arg.params,
                          aux.params = model$aux.params,
                          input.params = NULL,
                          ctx = ctx)
# seed the auto-regressive loop with the last predicted value of the 400th sample
pred <- as.numeric(as.array(infer[[1]])[1, 100])
actual <- trainY[, 401]
## Iterate over the time steps one by one
for (i in 1:pred_length) {
  data_auto <- mx.nd.array(trainX[, i, 401, drop = FALSE])
  m <- as.array(data_auto)
  ## feed the previous prediction back in as the target series (dim7 is the label)
  m[7, , ] <- as.numeric(as.array(pred))
  data_auto <- mx.nd.array(m)
  label <- mx.nd.array(trainY[i, 401, drop = FALSE])
  infer.data <- mx.io.arrayiter(data = data_auto,
                                label = label,
                                batch.size = 1,
                                shuffle = FALSE)
  ## use the RNN state values from the previous iteration
  infer <- mx.infer.rnn.one(infer.data = infer.data,
                            symbol = symbol,
                            ctx = ctx,
                            arg.params = model$arg.params,
                            aux.params = model$aux.params,
                            input.params = list(rnn.state = infer[[2]],
                                                rnn.state.cell = infer[[3]]))
  pred <- infer[[1]]
  predicted <- c(predicted, as.numeric(as.array(pred)))
}
## Creating plots
library(ggplot2)
library("reshape")
## Create plots of actual signal vs predicted signal
rl <- as.data.frame(actual)
pr <- as.data.frame(predicted)
pr$idx <- as.numeric(row.names(pr))
pr$actual <- rl$actual
melteddata <- melt(pr, id = 'idx')
ggplot(melteddata, aes(x = idx, y = value, colour = variable)) +
  geom_line() + ylab(label = "values") +
  xlab("number of time steps") + scale_colour_manual(values = c("red", "blue"))
## create plots of validation vs training error
epoch <- 1:200
loss <- as.data.frame(epoch)
loss$train_MSE <- logger$train
loss$validation_MSE <- logger$eval
melteddata <- melt(loss, id = 'epoch')
ggplot(melteddata, aes(x = epoch, y = value, colour = variable)) +
  geom_line() + ylab(label = "Mean Squared Error") + xlab("Epochs") +
  scale_colour_manual(values = c("red", "blue"))
## visualise the data
len <- 500
df <- data.frame(dim1 = numeric(len), dim2 = numeric(len), dim3 = numeric(len),
                 dim4 = numeric(len), dim5 = numeric(len), dim6 = numeric(len),
                 dim7 = numeric(len))
df$dim1 <- sin(pi/12 * (1:len))
df$dim2 <- sin(5 + pi/6 * (1:len))
df$dim3 <- sin(15 + pi/8 * (1:len))
df$dim4 <- sin(55 + pi/10 * (1:len))
df$dim5 <- sin(2 + pi/16 * (1:len))
df$dim6 <- sin(1 + pi/7 * (1:len))
df$dim7 <- 7*df$dim1 + 5*df$dim2 + 17*df$dim3 + 25*df$dim4 + 19*df$dim5 + 21*df$dim6
rangenorm <- function(x) {
  (x - min(x)) / (max(x) - min(x))
}
max_val <- apply(df, 2, max)
min_val <- apply(df, 2, min)
df <- apply(df, 2, rangenorm)
df1 <- as.data.frame(df)
df1$x <- 1:len
library(cowplot)
p1 <- ggplot(df1, aes(x = x, y = dim1)) + geom_line(color = "blue") + ylab("Dim 1") + xlab("time steps")
p2 <- ggplot(df1, aes(x = x, y = dim2)) + geom_line(color = "blue") + ylab("Dim 2") + xlab("time steps")
p3 <- ggplot(df1, aes(x = x, y = dim3)) + geom_line(color = "blue") + ylab("Dim 3") + xlab("time steps")
p4 <- ggplot(df1, aes(x = x, y = dim4)) + geom_line(color = "blue") + ylab("Dim 4") + xlab("time steps")
p5 <- ggplot(df1, aes(x = x, y = dim5)) + geom_line(color = "blue") + ylab("Dim 5") + xlab("time steps")
p6 <- ggplot(df1, aes(x = x, y = dim6)) + geom_line(color = "blue") + ylab("Dim 6") + xlab("time steps")
p7 <- ggplot(df1, aes(x = x, y = dim7)) + geom_line(color = "blue") + ylab("Dim 7") + xlab("time steps")
plot_grid(p1, p2, p3, p4, p5, p6, p7, labels = "AUTO")
#######