

@decisionmechanics
Created March 21, 2017 18:56
Predicting wine quality using a random forest classifier in SparkR
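library(readr)
library(dplyr)

# UCI white-wine quality data (semicolon-delimited)
url <- "https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-white.csv"

# Bin the integer quality score into three classes: < 6 "bad", 6 "average", > 6 "good"
df <-
  read_delim(url, delim = ";") %>%
  dplyr::mutate(taste = as.factor(ifelse(quality < 6, "bad", ifelse(quality > 6, "good", "average")))) %>%
  dplyr::select(-quality)

# Add a unique row id so test-set predictions can be joined back to the local data
df <- dplyr::mutate(df, id = dplyr::row_number())

print(df, n = 25)
table(df$taste)  # class counts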
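The nested `ifelse()` puts both class boundaries at a score of 6: anything below 6 is "bad", anything above 6 is "good", and exactly 6 is "average". Below is a minimal base-R check of that mapping on a few made-up scores. It uses a hypothetical helper, `quality_to_taste()`, that repeats the same expression. Unlike the script, the helper fixes the factor levels in their natural order; `as.factor()` in the script sorts them alphabetically ("average", "bad", "good").

```r
# Hypothetical helper repeating the nested ifelse() used to build `taste`
quality_to_taste <- function(quality) {
  label <- ifelse(quality < 6, "bad", ifelse(quality > 6, "good", "average"))
  # Explicit levels in natural order (as.factor() would sort them alphabetically)
  factor(label, levels = c("bad", "average", "good"))
}

scores <- c(3, 5, 6, 7, 9)  # made-up example scores
example_taste <- quality_to_taste(scores)
print(data.frame(quality = scores, taste = example_taste))
table(example_taste)
```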
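# Point SPARK_HOME at a local Spark 2.1.0 install (adjust the path for your machine)
Sys.setenv(SPARK_HOME = "/home/andrew/spark-2.1.0-bin-hadoop2.7")
# Load SparkR from the Spark distribution. It masks several dplyr verbs (e.g. select, mutate),
# which is why the local data-frame steps are written as dplyr::...
library(SparkR, lib.loc = c(file.path(Sys.getenv("SPARK_HOME"), "R", "lib")))
sparkR.session(master = "local[*]")

# Copy the local data frame into a SparkDataFrame
ddf <- createDataFrame(df)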
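# Roughly 70/30 train/test split. sample() keeps each row independently,
# so the training set is approximately, not exactly, 70% of the data
seed <- 12345
training_ddf <- sample(ddf, withReplacement = FALSE, fraction = 0.7, seed = seed)
# Test set = rows not drawn into training (a set difference over whole rows)
test_ddf <- except(ddf, training_ddf)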
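`except()` is a set difference with `EXCEPT DISTINCT` semantics. Because of that, the unique `id` column matters: if the data contain identical rows, those rows would be collapsed without the id, and a test row identical to some training row would be dropped. The sketch below mimics the sample-then-difference logic locally in base R on stand-in row ids. `split_ids()` is a hypothetical helper, and R's random number generator will not reproduce Spark's actual split for the same seed.

```r
# Hypothetical base-R analogue of sample(withReplacement = FALSE) + except(),
# applied to a vector of unique row ids (illustration only; not Spark's RNG)
split_ids <- function(ids, fraction, seed) {
  set.seed(seed)
  keep <- runif(length(ids)) < fraction  # independent Bernoulli draw per row
  train <- ids[keep]
  test <- setdiff(ids, train)            # set difference, like except()
  list(train = train, test = test)
}

ids <- seq_len(1000)                     # stand-in row ids
s <- split_ids(ids, fraction = 0.7, seed = 12345)
length(s$train) / length(ids)            # close to, but generally not exactly, 0.7
length(intersect(s$train, s$test))       # no id lands in both sets
```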
# Random forest classifier; "- id" keeps the row id out of the feature set
model <- spark.randomForest(training_ddf, taste ~ . - id, type = "classification", seed = seed)
summary(model)
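# Score the held-out rows and bring id + predicted label back into R (SparkR's select/collect)
predictions <- predict(model, test_ddf)
prediction_df <- collect(select(predictions, "id", "prediction"))

# Pair each test row's actual label with its prediction
actual_vs_predicted <-
  dplyr::inner_join(df, prediction_df, by = "id") %>%
  dplyr::select(id, actual = taste, predicted = prediction)

# Test-set accuracy, then the confusion matrix (rows: actual, columns: predicted)
mean(actual_vs_predicted$actual == actual_vs_predicted$predicted)
table(actual_vs_predicted$actual, actual_vs_predicted$predicted)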
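Overall accuracy does not show how each individual class fares. Below is a small base-R sketch that derives accuracy and per-class precision and recall from the same actual/predicted pair that feeds the confusion matrix. `class_metrics()` is a hypothetical helper, and it is applied here to made-up labels. In the script it could be called as `class_metrics(actual_vs_predicted$actual, actual_vs_predicted$predicted)`.

```r
# Hypothetical helper: accuracy plus per-class precision and recall
# from vectors of actual and predicted labels
class_metrics <- function(actual, predicted) {
  lv <- sort(union(as.character(actual), as.character(predicted)))
  cm <- table(actual = factor(actual, levels = lv),
              predicted = factor(predicted, levels = lv))  # rows: actual, cols: predicted
  tp <- diag(cm)                                            # correctly classified counts
  list(
    accuracy = sum(tp) / sum(cm),
    per_class = data.frame(
      class = lv,
      precision = as.numeric(tp / colSums(cm)),  # NaN if a class is never predicted
      recall = as.numeric(tp / rowSums(cm)),     # NaN if a class never occurs
      row.names = NULL,
      stringsAsFactors = FALSE
    )
  )
}

# Made-up labels, for illustration only
actual <- c("bad", "bad", "good", "average")
predicted <- c("bad", "good", "good", "average")
class_metrics(actual, predicted)  # accuracy 0.75; precision of "good" 0.5; recall of "bad" 0.5
```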
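# Persist the fitted model, then load it back (adjust the path for your machine)
model_file_path <- "/home/andrew/wine_random_forest_model"
write.ml(model, model_file_path, overwrite = TRUE)  # overwrite = TRUE so the script can be re-run
new_model <- read.ml(model_file_path)
summary(new_model)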