Last active
March 17, 2016 03:55
-
-
Save zufri/cca200fcaafc90ab5ba0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# ---- init ----
# Fall back to a local Spark installation when SPARK_HOME is not set
# (change this path to your Spark install directory).
if (!nzchar(Sys.getenv("SPARK_HOME"))) {
  Sys.setenv(SPARK_HOME = "D:/tools/spark-1.6.1-bin")
}
# Request the spark-csv package through the submit arguments unless the
# caller already supplied SPARKR_SUBMIT_ARGS. The same package is also passed
# via sparkPackages below.
if (!nzchar(Sys.getenv("SPARKR_SUBMIT_ARGS"))) {
  Sys.setenv(SPARKR_SUBMIT_ARGS = '"--packages" "com.databricks:spark-csv_2.10:1.4.0" "sparkr-shell"')
}
# Load SparkR from the R library bundled with the Spark distribution.
library(SparkR, lib.loc = file.path(Sys.getenv("SPARK_HOME"), "R", "lib"))
library(randomForest)
# Seed shared by the train/test split.
seed_var <- 123
sc <- sparkR.init(
  master = "local[*]",
  sparkEnvir = list(spark.driver.memory = "8g"),
  sparkPackages = "com.databricks:spark-csv_2.10:1.4.0"
)
sqlContext <- sparkRSQL.init(sc)
# ---- init ----
# ---- dataset ----
data_file_name <- "data_1m"
# Training CSV part file under <working dir>/data/training/<data_file_name>;
# change this to wherever the dataset actually lives.
ds <- read.df(
  sqlContext,
  file.path(getwd(), "data/training", data_file_name, "part-00000"),
  source = "com.databricks.spark.csv",
  header = "true"
)
ds_rowcount <- nrow(ds)
# Feature columns plus the IsDepDelay15Min label used as the model target.
cols <- c(
  "Month", "DayofMonth", "DayOfWeek", "DepTime", "UniqueCarrier",
  "Origin", "Dest", "Distance", "IsDepDelay15Min"
)
# Project the Spark DataFrame down to just those columns.
filtered_ds <- select(ds, cols)
#-----common function-----
# Binary-classification metrics for labels coded 1 (positive) / 0 (negative).
#
# Args:
#   predicted, actual: vectors of predicted and true labels. These may be
#     numeric, or strings such as "1"/"0", which compare equal to 1/0 through
#     R's coercion in `==`.
# Returns:
#   c(precision, recall, F1, accuracy). A ratio whose denominator is zero
#   comes out as NaN. Comparisons that evaluate to NA are not counted in any
#   confusion-matrix cell.
get_metrics <- function(predicted, actual) {
  pred_pos <- predicted == 1
  pred_neg <- predicted == 0
  act_pos <- actual == 1
  act_neg <- actual == 0
  # sum(..., na.rm = TRUE) counts the TRUE cells and skips NA comparisons.
  tp <- sum(pred_pos & act_pos, na.rm = TRUE)
  tn <- sum(pred_neg & act_neg, na.rm = TRUE)
  fp <- sum(pred_pos & act_neg, na.rm = TRUE)
  fn <- sum(pred_neg & act_pos, na.rm = TRUE)
  precision <- tp / (tp + fp)
  recall <- tp / (tp + fn)
  c(
    precision,
    recall,
    2 * precision * recall / (precision + recall),
    (tp + tn) / (tp + tn + fp + fn)
  )
}
# Collapse a categorical vector to its k most frequent categories.
#
# Args:
#   x: character vector or factor (anything as.factor() accepts).
#   k: number of most frequent categories to keep.
# Returns:
#   A factor in which every value outside the k most frequent categories,
#   including NA, is replaced by the level "rest".
topK <- function(x, k) {
  x <- as.factor(x)
  # nbins = nlevels(x) yields exactly one count per level. Without it,
  # tabulate() stops at the highest level actually present. A factor with
  # unused trailing levels then gets a shorter count vector, and the names<-
  # assignment below fails.
  counts <- tabulate(x, nbins = nlevels(x))
  names(counts) <- levels(x)
  keep <- names(tail(sort(counts), k))
  out <- as.character(x)
  out[!(out %in% keep)] <- "rest"
  factor(out)
}
# Turn a collected (local) flight data.frame into model-ready features.
#
# - Month, DayofMonth, DayOfWeek and IsDepDelay15Min become factors.
# - DepTime is cut down to its leading characters minus the last two
#   (HHMM -> HH); values too short to contain an hour become 0.
# - Distance becomes an integer.
# - Origin, Dest and UniqueCarrier are collapsed by topK() to their 25 most
#   frequent values (the rest -> "rest"). They are then replaced by one 0/1
#   indicator column per remaining level.
#
# Args:
#   df: local data.frame, e.g. the result of collect() on the selected columns.
# Returns:
#   df without the three raw categorical columns, with the indicator columns
#   appended on the right.
normalize_data <- function(df) {
  date_cols <- c("Month", "DayofMonth", "DayOfWeek")
  df[date_cols] <- lapply(df[date_cols], as.factor)

  # base:: is spelled out, presumably because SparkR masks substr()/ifelse().
  dep_time <- df$DepTime
  hour_part <- base::substr(dep_time, 1, nchar(dep_time) - 2)
  df$DepTime <- as.integer(base::ifelse(nchar(hour_part) > 0, hour_part, "0"))

  df$IsDepDelay15Min <- as.factor(df$IsDepDelay15Min)
  df$Distance <- as.integer(df$Distance)

  categorical <- c("Origin", "Dest", "UniqueCarrier")
  df[categorical] <- lapply(df[categorical], topK, k = 25)
  categ <- lapply(df[categorical], as.factor)
  # contrasts = FALSE keeps a column for every level of every factor
  # (full one-hot encoding) instead of dropping a reference level.
  indicators <- model.matrix(
    ~ . - 1,
    data = categ,
    contrasts.arg = lapply(categ, contrasts, contrasts = FALSE)
  )

  # The raw categorical columns are superseded by the indicator columns.
  df[categorical] <- NULL
  cbind(df, indicators)
}
#-----common function-----
# ---- split training data and test data ----
# Collect the selected columns once and normalise them in a single pass, so
# the training and test sets share the same factor levels and the same top-25
# indicator columns. Normalising each split separately lets topK() pick a
# different "top 25" per split. A category could then have its own column in
# one set and fall into "rest" in the other. Category frequencies are counted
# over the whole dataset; the label is not involved in that step.
#
# The split itself is drawn locally. The former sample() + except() approach
# removed from the test set every row identical to some training row, so
# duplicated records ended up in neither set.
local_ds <- normalize_data(collect(filtered_ds))
set.seed(seed_var)
# Bernoulli draw with the same 0.8 training fraction as before. base::/stats::
# are explicit as a precaution, since SparkR (attached above) masks several
# base names, e.g. sample().
is_train <- stats::runif(base::nrow(local_ds)) < 0.8
local_ds_trainset <- local_ds[is_train, , drop = FALSE]
local_ds_testset <- local_ds[!is_train, , drop = FALSE]
# ---- train and evaluate: everything below runs locally in R ----
# Restrict both sets to their shared columns so predict() sees exactly the
# predictors the forest was trained on.
cols <- intersect(names(local_ds_trainset), names(local_ds_testset))
local_ds_trainset <- local_ds_trainset[, cols, drop = FALSE]
local_ds_testset <- local_ds_testset[, cols, drop = FALSE]
# Hold the true labels aside and remove them from the test predictors.
test_is_delay <- local_ds_testset$IsDepDelay15Min
local_ds_testset$IsDepDelay15Min <- NULL
# base::as.data.frame is qualified explicitly, presumably to bypass SparkR's
# as.data.frame generic.
rf <- randomForest(
  IsDepDelay15Min ~ .,
  data = base::as.data.frame(local_ds_trainset),
  ntree = 40
)
rf.pr <- predict(rf, newdata = base::as.data.frame(local_ds_testset))
# Compare labels as plain strings. as.character() is the public equivalent of
# calling the S3 method as.vector.factor() directly. get_metrics() compares
# against 1/0, which matches "1"/"0" strings; assumes the label is coded
# that way (TODO confirm against the data).
m.rf <- get_metrics(as.character(rf.pr), as.character(test_is_delay))
print(sprintf(
  "Random Forest: precision=%0.2f, recall=%0.2f, F1=%0.2f, accuracy=%0.2f",
  m.rf[1], m.rf[2], m.rf[3], m.rf[4]
))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment