Skip to content

Instantly share code, notes, and snippets.

@nqbao
Last active June 25, 2016 15:48
Show Gist options
  • Save nqbao/2eaa9a3f34490fdded835098027d3cf1 to your computer and use it in GitHub Desktop.
Save nqbao/2eaa9a3f34490fdded835098027d3cf1 to your computer and use it in GitHub Desktop.
# Price prediction with k-means
# Hypothesis: lotsize, bedrooms, bathrms and stories affect price.
# We use k-means to group houses into N clusters and use the mean price
# of each house's cluster as its predicted price.
library(ggplot2)
library(plyr)
library(arimo)
# Fetch the "housing" DDF and materialise it as a local object by asking
# head() for as many rows as the DDF reports.
# NOTE(review): assumes arimo's head() method returns all rows when given
# the full row count — confirm against the arimo docs.
ddf <- arimo.getDDF("housing")
row_total <- nrow(ddf)
df <- head(ddf, row_total)
# Randomly split a data.frame into a training part and a test part.
#
# dataframe: the data.frame to split.
# ratio:     fraction of rows (truncated to a whole row count) that go to the
#            training set; every remaining row goes to the test set.
# seed:      optional RNG seed; when non-NULL, set.seed() is called first,
#            which changes the session's global RNG state.
# Returns list(train = <data.frame>, test = <data.frame>). Training rows come
# out in sampled (shuffled) order; test rows keep their original order.
splitdf <- function(dataframe, ratio = 0.8, seed = NULL) {
  if (!is.null(seed)) set.seed(seed)
  stopifnot(is.numeric(ratio), length(ratio) == 1L, ratio >= 0, ratio <= 1)
  n_rows <- nrow(dataframe)
  # sample.int(n, k) draws the same indices as sample(1:n, k) for n >= 1,
  # but is also safe when n == 0 (1:0 would be c(1, 0)).
  train_index <- sample.int(n_rows, size = floor(ratio * n_rows))
  # A logical mask for the complement: dataframe[-integer(0), ] would select
  # *no* rows, emptying the test set whenever the training sample is empty.
  in_train <- seq_len(n_rows) %in% train_index
  # drop = FALSE keeps single-column inputs as data.frames.
  list(
    train = dataframe[train_index, , drop = FALSE],
    test = dataframe[!in_train, , drop = FALSE]
  )
}
# Split into train/test sets. No seed is passed, so the split differs on
# every run.
split <- splitdf(df)

# see the histogram of price
qplot(price, data = df, geom = "histogram")

# Cluster the training rows on columns 3:6 only.
# NOTE(review): assumes columns 3:6 are the numeric features named in the
# header (lotsize, bedrooms, bathrms, stories) — confirm against the schema.
# The features are not scaled, so the widest-ranging column likely
# dominates kmeans' Euclidean distance.
features <- split$train[3:6]
# nstart > 1 tries several random initial centers and keeps the best fit,
# so the clustering is less sensitive to one unlucky start.
kmeanClusters <- kmeans(features, centers = 5, nstart = 10)
# Bug fix: the fitted model is `kmeanClusters`; `dfCluster` was never defined.
split$train$cluster <- kmeanClusters$cluster

ggplot(data = subset(split$train, cluster == 1), aes(price)) + geom_density()
ggplot(data = subset(split$train, cluster == 1), aes(price)) + geom_histogram()

# Mean and sd of price per cluster. The ~(cluster) formula makes ddply name
# the grouping column "(cluster)".
meanPrice <- ddply(split$train, ~(cluster), summarise,
                   mean = mean(price), sd = sd(price))
# Return the index of the row of `centers` nearest to `row` by squared
# Euclidean distance. Ties go to the lowest index, as with which.min().
#
# row:     numeric vector (or one-row data.frame) with one value per column
#          of `centers`.
# centers: matrix or data.frame with one cluster center per row, e.g. the
#          `centers` component of a kmeans() fit.
# Returns a single unnamed integer index, or NA_integer_ when `centers` has
# no rows.
findCluster <- function(row, centers) {
  centers <- as.matrix(centers)
  row <- unlist(row, use.names = FALSE)
  stopifnot(length(row) == ncol(centers))
  if (nrow(centers) == 0L) {
    return(NA_integer_)
  }
  # t(centers) holds one center per column, so `row` recycles down each
  # column and colSums() yields one squared distance per center.
  sq_dist <- colSums((t(centers) - row)^2)
  unname(which.min(sq_dist))
}
# Assign each test row to its nearest training-cluster center. The test
# rows use the same feature columns the model was fitted on. The model is
# `kmeanClusters`; `dfCluster` was never defined.
test_features <- as.matrix(split$test[, colnames(features), drop = FALSE])
split$test$cluster <- apply(
  test_features, 1,
  function(row) findCluster(row, kmeanClusters$centers)
)
# Predict each test row's price as its cluster's mean training price.
# Bug fix: the old sum(subset(meanPrice, ...)) added up the whole summary
# row (cluster id + mean + sd) instead of returning only the mean.
# ddply(~(cluster), ...) names the grouping column "(cluster)".
split$test$predictPrice <-
  meanPrice$mean[match(split$test$cluster, meanPrice[["(cluster)"]])]
# calculate error on test set using RMSE
# Root-mean-squared error between observed values `y0` and predictions `y1`.
rmse <- function(y0, y1) {
  squared_error <- (y0 - y1)^2
  sqrt(mean(squared_error))
}
# RMSE of the predicted prices on the held-out test set.
with(split$test, rmse(price, predictPrice))
# r-squared
# Coefficient of determination (R^2) of predictions `y1` against observed
# values `y0`, computed as 1 - SS_res / SS_tot. The result is negative when
# the predictions fit worse than the constant mean of `y0`.
rq <- function(y0, y1) {
  ss_res <- sum((y0 - y1)^2)
  ss_tot <- sum((y0 - mean(y0))^2)
  1 - ss_res / ss_tot
}
# R^2 on the test set. A negative value shows the model is very poor: it does
# worse than always predicting the test set's mean price.
with(split$test, rq(price, predictPrice))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment