# Author: Himanshu Sikaria | |
# Title : Kaggle TFI Restaurant Revenue Prediction tutorial | |
# Model: Random Forest | |
library(randomForest)
require(party)
library(dplyr)
library(ggplot2)
library(reshape)
# Reading the data ----
# stringsAsFactors = TRUE keeps City.Group / Type as factors. Since R 4.0.0
# read.csv() returns character columns by default, and randomForest() only
# treats factor columns as categorical predictors.
train <- read.csv("~/Downloads/train.csv", stringsAsFactors = TRUE)
test <- read.csv("~/Downloads/test.csv", stringsAsFactors = TRUE)

# Removing outliers: keep only training restaurants below the revenue cap.
revenue_cap <- 16000000
train <- train[train$revenue < revenue_cap, ]

# Combining train and test so every feature transformation below is applied
# identically to both. The target is set aside and the number of training
# rows is remembered so the combined frame can be split back later.
target <- train$revenue
train_row <- nrow(train)
train$revenue <- NULL
full <- rbind(train, test)
# Plotting histogram for P1 to P37 ----
# Select the P-columns by name rather than by position, so the plot does not
# depend on how many descriptive columns precede them.
p_cols <- grep("^P[0-9]+$", names(train), value = TRUE)
d <- melt(train[, p_cols, drop = FALSE])
# A ggplot object only renders when printed; print() explicitly so the plot
# also appears when the script is run with source().
print(
  ggplot(d, aes(x = value)) +
    facet_wrap(~variable, scales = "free_x") +
    geom_histogram()
)
# Splitting date ----
# Open.Date is "mm/dd/yyyy". Parse it once and derive the calendar parts from
# the parsed Date instead of slicing fixed character positions, which would
# break on dates that are not zero-padded.
full$Date <- as.Date(strptime(as.character(full$Open.Date), "%m/%d/%Y"))
full$year <- as.factor(format(full$Date, "%Y"))
full$month <- as.factor(format(full$Date, "%m"))
full$day <- as.numeric(format(full$Date, "%d"))

# How old the restaurant is, measured against a single fixed reference date.
reference_date <- as.Date("2014-02-02")
full$days <- as.numeric(reference_date - full$Date)
# Approximate age in months, assuming 30-day months.
full$months <- full$days / 30
# Restaurant age (months since opening) against revenue.
# `train` no longer carries `revenue` (it was moved into `target`), and the
# age feature was added to `full`, not `train`. The first `train_row` rows of
# `full` are the training restaurants in the same order as `target`.
age_revenue <- data.frame(
  revenue = target,
  months = full$months[seq_len(train_row)]
)
print(
  ggplot(age_revenue, aes(x = revenue, y = months)) +
    geom_point() +
    geom_smooth() +
    ggtitle("Life of restaurant (months) vs Revenue")
)
# Removing columns which are not to be used ----
# Id is an identifier, Open.Date/Date are superseded by the derived date
# features, and City is not used as a predictor.
full[c("Id", "Open.Date", "Date", "City")] <- NULL

# Splitting into train and test ----
# `full` was built as rbind(train, test), so its first `train_row` rows are the
# training restaurants. A seq_len()-based mask avoids the `1:n` /
# `(n + 1):nrow()` footgun, which selects spurious rows when a side is empty.
is_train <- seq_len(nrow(full)) <= train_row
train <- full[is_train, ]
train$revenue <- target
test <- full[!is_train, ]
row.names(test) <- NULL
# Random Forest ----
# Fixed RNG seed so the fitted forest is reproducible between runs.
set.seed(147)
fit <- randomForest(revenue ~ ., data = train, ntree = 1000)

# Variable-importance chart for the fitted forest.
varImpPlot(fit)

# Predicted revenue for each test restaurant.
pred <- predict(fit, newdata = test, type = "response")
# Preparing the required output format ----
# The engineered `test` frame no longer has an Id column, so the restaurant
# Ids are re-read from the raw test file. They go into their own variable
# instead of overwriting `test`.
test_ids <- read.csv("~/Downloads/test.csv")$Id
# Guard against Ids and predictions silently falling out of alignment.
stopifnot(length(test_ids) == length(pred))
final <- data.frame(ID = test_ids, Prediction = pred)
write.csv(final, "tfi_rf.csv", row.names = FALSE)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment