Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@zufri
Last active March 17, 2016 03:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zufri/cca200fcaafc90ab5ba0 to your computer and use it in GitHub Desktop.
Save zufri/cca200fcaafc90ab5ba0 to your computer and use it in GitHub Desktop.
# ----init---
# Use the local Spark installation unless SPARK_HOME is already set.
if (!nzchar(Sys.getenv("SPARK_HOME"))) {
  # Replace with the Spark installation directory on your machine.
  Sys.setenv(SPARK_HOME = "D:/tools/spark-1.6.1-bin")
}
# Ask spark-submit to load the spark-csv package unless the caller has
# already configured the submit arguments.
if (!nzchar(Sys.getenv("SPARKR_SUBMIT_ARGS"))) {
  Sys.setenv(SPARKR_SUBMIT_ARGS = '"--packages" "com.databricks:spark-csv_2.10:1.4.0" "sparkr-shell"')
}
library(SparkR, lib.loc = file.path(Sys.getenv("SPARK_HOME"), "R", "lib"))
library(randomForest)
# Seed shared by the Spark-side train/test split (and the model, below).
seed_var <- 123
sc <- sparkR.init(
  master = "local[*]",
  sparkEnvir = list(spark.driver.memory = "8g"),
  sparkPackages = "com.databricks:spark-csv_2.10:1.4.0"
)
sqlContext <- sparkRSQL.init(sc)
# ----init---
data_file_name <- "data_1m"
# dataset
# Replace with the path to the actual dataset.
ds <- read.df(
  sqlContext,
  file.path(getwd(), "data/training", data_file_name, "part-00000"),
  source = "com.databricks.spark.csv",
  header = "true"
)
ds_rowcount <- nrow(ds)
# feature list (the last entry, IsDepDelay15Min, is the label)
cols <- c("Month", "DayofMonth", "DayOfWeek", "DepTime", "UniqueCarrier",
          "Origin", "Dest", "Distance", "IsDepDelay15Min")
# keep only the feature and label columns listed above
filtered_ds <- select(ds, cols)
#-----common function-----
# Binary-classification metrics for predicted vs. actual labels.
#
# predicted, actual: equal-length vectors whose elements compare equal to 1
#   (positive) or 0 (negative). Elements matching neither value, or NA, are
#   not counted in any cell of the confusion matrix.
# Returns c(precision, recall, F1, accuracy), named so it can be indexed by
#   position (as before) or by name.
#   - precision is NaN when nothing was predicted positive (tp + fp == 0);
#   - recall is NaN when there are no actual positives (tp + fn == 0);
#   - F1 is 2*tp / (2*tp + fp + fn). This equals the harmonic mean of
#     precision and recall whenever that is defined, and is 0 rather than
#     NaN when tp == 0 but fp + fn > 0;
#   - accuracy is NaN when no element was counted.
get_metrics <- function(predicted, actual) {
  # sum(..., na.rm = TRUE) counts TRUEs and ignores NAs, like length(which(.))
  tp <- sum(predicted == 1 & actual == 1, na.rm = TRUE)
  tn <- sum(predicted == 0 & actual == 0, na.rm = TRUE)
  fp <- sum(predicted == 1 & actual == 0, na.rm = TRUE)
  fn <- sum(predicted == 0 & actual == 1, na.rm = TRUE)
  precision <- tp / (tp + fp)
  recall <- tp / (tp + fn)
  # Closed form of 2PR / (P + R). It avoids 0/0 when tp == 0 but errors exist.
  f1 <- 2 * tp / (2 * tp + fp + fn)
  accuracy <- (tp + tn) / (tp + tn + fp + fn)
  c(precision = precision, recall = recall, F1 = f1, accuracy = accuracy)
}
# Keep only the k most frequent categories of a categorical vector and
# collapse every other value (including NA) into the level 'rest'.
#
# x: a vector accepted by as.factor() (e.g. character or factor).
# k: number of most frequent categories to keep.
# Returns a factor with the same length as x. Ties in frequency are broken by
# the ordering sort() produces, not by an explicit rule.
topK <- function(x, k) {
  x <- as.factor(x)
  # nbins = nlevels(x) keeps the counts aligned with levels(x) even when x is
  # a factor whose last levels are unused. Without it, tabulate() returns a
  # shorter vector and the names<- assignment below fails.
  tbl <- tabulate(x, nbins = nlevels(x))
  names(tbl) <- levels(x)
  x <- as.character(x)
  levelsToKeep <- names(tail(sort(tbl), k))
  # %in% never returns NA, so NA values are also mapped to 'rest'.
  x[!(x %in% levelsToKeep)] <- 'rest'
  factor(x)
}
# Convert the raw feature columns collected from Spark into model-ready
# columns.
#
# df: a local data.frame with (at least) the columns Month, DayofMonth,
#   DayOfWeek, DepTime, UniqueCarrier, Origin, Dest, Distance and
#   IsDepDelay15Min.
# Returns a data.frame in which:
#   - Month, DayofMonth, DayOfWeek and IsDepDelay15Min are factors;
#   - DepTime is reduced to its leading part: all characters except the last
#     two ("1345" -> 13), or 0 when nothing remains ("45" -> 0);
#   - Distance is an integer;
#   - Origin, Dest and UniqueCarrier are replaced by one indicator column per
#     category, after topK() collapses all but their 25 most frequent values
#     into 'rest'.
normalize_data <- function(df) {
  df$Month <- as.factor(df$Month)
  df$DayofMonth <- as.factor(df$DayofMonth)
  df$DayOfWeek <- as.factor(df$DayOfWeek)
  # only extract departure hour: drop the last two characters (the minutes).
  # base:: is spelled out explicitly, presumably because the attached SparkR
  # package masks these functions. NOTE(review): confirm.
  dep_hour <- base::substr(df$DepTime, 1, nchar(df$DepTime) - 2)
  df$DepTime <- as.integer(base::ifelse(nchar(dep_hour) > 0, dep_hour, "0"))
  df$IsDepDelay15Min <- as.factor(df$IsDepDelay15Min)
  df$Distance <- as.integer(df$Distance)
  df$Origin <- topK(df$Origin, 25)
  df$Dest <- topK(df$Dest, 25)
  df$UniqueCarrier <- topK(df$UniqueCarrier, 25)
  categ <- lapply(df[, c("Origin", "Dest", "UniqueCarrier")], as.factor)
  # Create one indicator column per level of each categorical variable.
  # contrasts = FALSE keeps every level (no reference level is dropped), and
  # "- 1" removes the intercept column.
  bin_df <- model.matrix(
    ~ . - 1,
    data = categ,
    contrasts.arg = lapply(categ, contrasts, contrasts = FALSE)
  )
  # Drop the original categorical columns; the indicator columns replace them.
  df$Dest <- NULL
  df$UniqueCarrier <- NULL
  df$Origin <- NULL
  cbind(df, bin_df)
}
#-----common function-----
# split training data and test data
# Spark side: sample ~80% of the rows (without replacement, seeded) for
# training. except() then keeps the rows of filtered_ds not in that sample
# for testing.
ds_trainset <- sample(filtered_ds, FALSE, 0.8, seed_var)
local_ds_trainset <- normalize_data(collect(ds_trainset))
ds_testset <- except(filtered_ds, ds_trainset)
local_ds_testset <- normalize_data(collect(ds_testset))
# below operations are executed in R.
# The two sets are normalized independently, so topK() may keep different
# categories in each. Only the columns present in both sets are used.
# NOTE(review): factor levels (e.g. Month, DayofMonth) are also derived per
# set. Confirm they match, otherwise predict() may reject the test data.
cols <- intersect(names(local_ds_trainset), names(local_ds_testset))
local_ds_trainset <- local_ds_trainset[, cols]
local_ds_testset <- local_ds_testset[, cols]
# Separate the label from the test features.
test_is_delay <- local_ds_testset$IsDepDelay15Min
local_ds_testset$IsDepDelay15Min <- NULL
# Seed R's RNG so the forest is reproducible. seed_var otherwise only seeds
# the Spark-side split above.
set.seed(seed_var)
rf <- randomForest(IsDepDelay15Min ~ .,
                   data = base::as.data.frame(local_ds_trainset),
                   ntree = 40)
rf.pr <- predict(rf, newdata = base::as.data.frame(local_ds_testset))
# Convert the factor labels back to their level strings, which get_metrics
# compares against 1 / 0.
m.rf <- get_metrics(as.character(rf.pr), as.character(test_is_delay))
print(sprintf("Random Forest: precision=%0.2f, recall=%0.2f, F1=%0.2f, accuracy=%0.2f",
              m.rf[1], m.rf[2], m.rf[3], m.rf[4]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment