Skip to content

Instantly share code, notes, and snippets.

@DexGroves
Created April 15, 2016 16:28
Show Gist options
  • Save DexGroves/90acc3e501bcc9ec20e30265d37eb36d to your computer and use it in GitHub Desktop.
Save DexGroves/90acc3e501bcc9ec20e30265d37eb36d to your computer and use it in GitHub Desktop.
Fit an mxnet multilayer perceptron to a two-class spiral dataset
library("mlbench")
library("ggplot2")
library("mxnet")
# Plot an mxnet model's decision surface over a square 2-D grid, overlaid
# with the training points coloured by class.
#
# Arguments:
#   model      Fitted mxnet feed-forward model (e.g. returned by mx.mlp())
#              that takes two input features.
#   data       Two-column numeric matrix (or data frame) of input points; its
#              overall min/max set the extent of both grid axes.
#   labels     Class label for each row of `data`, used to colour the points.
#              Defaults to the global `train_y` so existing two-argument calls
#              behave as before; pass it explicitly in new code.
#   resolution Number of grid points per axis (resolution^2 predictions).
#
# Returns a ggplot object; print() it to draw.
plot_mxmodel <- function(model, data, labels = train_y, resolution = 500) {
  stopifnot(
    ncol(data) == 2L,
    length(labels) == nrow(data)
  )

  axis_vals <- seq(from = min(data), to = max(data), length.out = resolution)
  # Build the grid once; the same data frame feeds both predict() and the plot.
  grid <- expand.grid(x = axis_vals, y = axis_vals)

  # Each grid row is one sample, so state the layout explicitly instead of
  # relying on mxnet's "auto" layout detection.
  mx_pred <- predict(model, as.matrix(grid), array.layout = "rowmajor")

  # NOTE(review): predict() is assumed to return one row per output node and
  # one column per sample (the original code transposed it). Taking the last
  # row yields a plain numeric vector -- not a matrix column -- and is the
  # class-1 probability for both a single logistic output and a 2-node softmax.
  grid$p <- if (is.matrix(mx_pred)) {
    mx_pred[nrow(mx_pred), ]
  } else {
    as.vector(mx_pred)
  }

  # Name the point columns explicitly rather than relying on data.frame()'s
  # default X1/X2 names, which only appear when `data` has no column names.
  points_df <- data.frame(
    X1 = data[, 1],
    X2 = data[, 2],
    label = factor(labels)
  )

  ggplot() +
    # geom_raster() is ggplot2's fast path for equal-sized tiles on a grid.
    geom_raster(data = grid, aes(x = x, y = y, fill = p)) +
    scale_fill_continuous(low = "sandybrown", high = "steelblue2") +
    geom_point(
      data = points_df,
      aes(x = X1, y = X2, colour = label),
      size = 1
    ) +
    scale_colour_manual(values = c("red", "blue"))
}
# Two-class spiral dataset: 1000 points, 1.5 turns, no noise.
# NOTE(review): mlbench.spirals() draws its class assignment with R's RNG,
# which mx.set.seed() below does not touch -- seed it so the dataset is
# reproducible from run to run.
set.seed(0)
spiral_data <- mlbench.spirals(n = 1000, cycles = 1.5, sd = 0)
plot(spiral_data)

# Numeric feature matrix and 0/1 numeric labels (factor levels 1/2 -> 0/1).
train_x <- data.matrix(spiral_data$x)
train_y <- as.numeric(spiral_data$classes) - 1

# Seed mxnet's own RNG (used e.g. for weight initialisation).
mx.set.seed(0)
# Three hidden ReLU layers of 7 units feeding one logistic output unit;
# each row of train_x is one sample (array.layout = "rowmajor"). Training
# progress is reported as RMSE between the output and the 0/1 label.
mx_model <- mx.mlp(train_x,
                   train_y,
                   num.round = 1000,
                   hidden_node = c(7, 7, 7),
                   activation = "relu",
                   out_activation = "logistic",
                   out_node = 1,
                   array.batch.size = 50,
                   learning.rate = 0.03,
                   momentum = 0.1,
                   array.layout = "rowmajor",
                   initializer = mx.init.normal(1),
                   eval.metric = mx.metric.rmse)

# Wrap in print() so the plot is drawn even when this file is source()d,
# where top-level values are not auto-printed.
print(plot_mxmodel(mx_model, train_x))

# Summary of the predicted probabilities on the training set (base R; the
# file does not attach magrittr, so `%>%` is unavailable here):
# summary(t(predict(mx_model, train_x, array.layout = "rowmajor")))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment