Last active
June 25, 2016 15:48
-
-
Save nqbao/2eaa9a3f34490fdded835098027d3cf1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# pricing prediction with kmeans
# hypothesis: lotsize, bedrooms, bathrms, stories affect price
# we use kmeans to cluster into N clusters and use the mean price
# of that cluster as the prediction
library(ggplot2) | |
library(plyr) | |
library(arimo) | |
# Fetch the 'housing' dataset handle through the arimo package.
ddf <- arimo.getDDF('housing')
# Request as many rows as `ddf` reports, i.e. pull the whole dataset into `df`.
df <- head(ddf, nrow(ddf))
#' Randomly split a data frame into a training set and a test set.
#'
#' @param dataframe Data frame whose rows are to be partitioned.
#' @param ratio Fraction of rows placed in the training set (default 0.8).
#'   The training size is `floor(ratio * nrow(dataframe))`.
#' @param seed Optional RNG seed; when non-NULL it is passed to `set.seed()`
#'   so the split is reproducible.
#' @return A list with elements `train` and `test`, both data frames. Test
#'   rows keep their original relative order.
splitdf <- function(dataframe, ratio = 0.8, seed = NULL) {
  if (!is.null(seed)) set.seed(seed)
  n <- nrow(dataframe)
  all_rows <- seq_len(n)
  trainindex <- sample(all_rows, size = floor(ratio * n))
  # Use setdiff rather than negative indexing: `x[-integer(0), ]` selects
  # *no* rows, which would leave the test set empty whenever the training
  # set is empty.
  testindex <- setdiff(all_rows, trainindex)
  # drop = FALSE keeps single-column inputs as data frames.
  list(
    train = dataframe[trainindex, , drop = FALSE],
    test = dataframe[testindex, , drop = FALSE]
  )
}
# Partition the data with splitdf's default ratio. No seed is passed, so the
# split (and everything downstream) changes from run to run.
split <- splitdf(df)

# see the histogram of price
qplot(price, data = df, geom = "histogram")

# Cluster the training rows on columns 3:6 only.
# NOTE(review): presumably these are lotsize, bedrooms, bathrms and stories;
# confirm against the column order of `df`.
features <- split$train[3:6] # only keep numeric columns
kmeanClusters <- kmeans(features, centers = 5)

# Attach each training row's cluster label. The fitted model lives in
# `kmeanClusters`; the original read it from `dfCluster`, which is never
# defined, so this line failed.
split$train$cluster <- kmeanClusters$cluster

# Price distribution inside cluster 1, as a density and as a histogram.
ggplot(data = subset(split$train, cluster == 1), aes(price)) + geom_density()
ggplot(data = subset(split$train, cluster == 1), aes(price)) + geom_histogram()

# find the mean price for each cluster
# The `~(cluster)` grouping spec is kept as-is because later code refers to
# the resulting key column by the name `(cluster)`.
meanPrice <- ddply(
  split$train, ~(cluster), summarise,
  mean = mean(price), sd = sd(price)
)
#' Find the index of the centre closest to an observation.
#'
#' @param row Numeric vector of feature values for one observation.
#' @param centers Matrix (or data frame) with one centre per row and one
#'   column per feature, in the same order as `row`.
#' @return Integer index of the row of `centers` with the smallest squared
#'   Euclidean distance to `row`; ties resolve to the first such centre.
findCluster <- function(row, centers) {
  if (nrow(centers) == 0) {
    stop("`centers` must contain at least one centre", call. = FALSE)
  }
  # Without this check a mismatched `row` is silently recycled against each
  # centre and yields a meaningless distance.
  if (length(row) != ncol(centers)) {
    stop("`row` must have one value per column of `centers`", call. = FALSE)
  }
  sq_dist <- vapply(
    seq_len(nrow(centers)),
    function(i) sum((row - centers[i, ])^2),
    numeric(1)
  )
  which.min(sq_dist)
}
# Assign every test row to its nearest k-means centre, using the same feature
# columns (3:6) that the clustering was fitted on. The centres come from
# `kmeanClusters`; the original read `dfCluster`, which is never defined.
split$test$cluster <- apply(
  split$test[, 3:6], 1,
  function(row) findCluster(row, kmeanClusters$centers)
)

# Predict each test row's price as the mean training price of its cluster.
# The original summed the entire matching row of `meanPrice`, i.e.
# cluster id + mean + sd, which inflated every prediction by the cluster's
# standard deviation. The first column of `meanPrice` is the grouping key
# written by ddply, so it is matched by position rather than by name.
split$test$predictPrice <-
  meanPrice$mean[match(split$test$cluster, meanPrice[[1]])]
# calculate error on test set using RMSE
#' Root-mean-square error between two numeric vectors.
#'
#' @param y0 Observed values.
#' @param y1 Predicted values.
#' @return The square root of the mean squared difference `y0 - y1`.
rmse <- function(y0, y1) {
  residuals <- y0 - y1
  sqrt(mean(residuals^2))
}
# RMSE of the predicted prices against the actual prices of the test rows.
with(split$test, rmse(price, predictPrice))
# r-squared
#' Coefficient of determination of predictions against observations.
#'
#' @param y0 Observed values.
#' @param y1 Predicted values.
#' @return `1 - SS_res / SS_tot`, where `SS_tot` is the sum of squared
#'   deviations of `y0` from its own mean and `SS_res` is the sum of
#'   squared differences `y0 - y1`.
rq <- function(y0, y1) {
  total_ss <- sum((y0 - mean(y0))^2)
  residual_ss <- sum((y0 - y1)^2)
  1 - residual_ss / total_ss
}
# A negative r-squared here means the model is really poor: its predictions
# fit the test prices worse than simply using their mean.
with(split$test, rq(price, predictPrice))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment