Created
April 20, 2015 17:54
-
-
Save TATABOX42/e1db659baee306cc0cc5 to your computer and use it in GitHub Desktop.
K-means introduction in R: code used for post: http://btovar.com/2015/04/introduction-to-k-means-in-r/
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# *********************************************** | |
# *********************************************** | |
# Author: Benjamin Tovar | |
# Date: April 20, 2015 | |
# Post: http://btovar.com/2015/04/introduction-to-k-means-in-r/ | |
# *********************************************** | |
# *********************************************** | |
library(ggplot2)
library(reshape)
##################################################### | |
# Case study 1: the basics of the K-means in R | |
##################################################### | |
# Simulate three independent normal samples and collect them in `dataset`.
# The seed is fixed so every run draws exactly the same values; the draw
# order (x, then y, then z) is part of that reproducibility.
set.seed(1121)

# Number of observations drawn for each variable.
l <- 100

# x: mean 1, sd 0.5
x <- rnorm(l, mean = 1, sd = 0.5)
# print() is explicit so the statistics are also shown when the script is
# run through source(), which does not auto-print top-level values.
print(mean(x))
print(var(x))
print(sd(x))
print(summary(x))

# y: mean 5, sd 1
y <- rnorm(l, mean = 5, sd = 1)
print(mean(y))
print(var(y))
print(sd(y))
print(summary(y))

# z: mean 4, sd 0.4
z <- rnorm(l, mean = 4, sd = 0.4)
print(mean(z))
print(var(z))
print(sd(z))
print(summary(z))

# Wide table: one column per simulated variable, one row per draw index.
dataset <- data.frame(x = x, y = y, z = z)
# Reshape the wide table into long format: one row per drawn value, tagged
# with the name of the column it came from. The measured columns are named
# explicitly so melt() does not have to guess which columns are ids.
dataset_m <- melt(dataset, measure.vars = names(dataset))
# Running row index (used as the x axis of the scatter plots). It is derived
# from the melted table itself so it always matches the row count.
dataset_m <- cbind(id = seq_len(nrow(dataset_m)), dataset_m)
colnames(dataset_m) <- c("id", "class", "value")
# Plot the artificial dataset to inspect how its values are distributed.

# Pooled density of every value, ignoring the class it came from.
# fill/colour are constants here, so they are set outside aes(): inside
# aes() the string "red" is treated as a one-level data variable, which
# draws the default palette colour and adds a legend that then has to be
# hidden.
ggplot(dataset_m) +
  aes(x = value) +
  geom_density(fill = "red", colour = "red", alpha = 0.4) +
  labs(title = "Density plot of all values (omitting classes)")

# One density curve per class, filled and outlined by class.
ggplot(dataset_m) +
  aes(x = value) +
  geom_density(aes(fill = class, colour = class), alpha = 0.4) +
  labs(title = "Density plot for all classes")

# Scatter plot of every value against its row id; point shape marks the class.
ggplot(dataset_m) +
  aes(x = id, y = value) +
  geom_point(aes(shape = class), size = 2.5, alpha = 0.9) +
  labs(title = "Scatter plot of dataset")
# Run K-means on the pooled values for k = 2, 3 and 4 centroids and plot
# each clustering. The three fits run in the same order as before, so they
# consume the random number stream identically; after the loop `k`, `km`
# and `dataset_km` hold the k = 4 results.
for (k in c(2, 3, 4)) {
  ## run the Kmeans algorithm with k centroids
  km <- kmeans(dataset_m$value, centers = k)
  # Attach the assigned cluster to every row as a factor so it can be
  # mapped to a discrete colour.
  dataset_km <- cbind(dataset_m, cluster = as.factor(km$cluster))
  # plot each sample: shape = true class, colour = assigned cluster.
  # print() is required because plots are not auto-printed inside a loop.
  print(
    ggplot(dataset_km) +
      aes(x = id, y = value) +
      geom_point(
        aes(shape = class, colour = cluster),
        size = 2.5,
        alpha = 0.9
      ) +
      labs(title = paste0("Dataset K-means clustering | k=", k))
  )
}
##############
# COMPUTE THE BEST VALUE
# FOR PARAMETER K
##############
# Fit K-means once for every k from 1 to max_k_size and record the total
# within-cluster sum of squares of each fit.
# NOTE(review): these fits cluster the wide `dataset` table, whereas the
# fits for k = 2, 3, 4 cluster the pooled `dataset_m$value` -- confirm that
# the difference is intended.
max_k_size <- 10
k_values <- seq_len(max_k_size)
within_ss <- vapply(
  k_values,
  function(n_centers) kmeans(dataset, centers = n_centers)$tot.withinss,
  numeric(1)
)

# set data frame object: one row per k with its within-groups sum of squares
error <- data.frame(k = k_values, error = within_ss)

# Error calibration ("elbow") curve. ylim() removes points that fall outside
# 0-150, so the limit has to be widened if the data changes. The dashed
# vertical line marks k = 3.
ggplot(error) +
  aes(x = k, y = error) +
  geom_point(colour = "#f96161", size = 3) +
  ylim(0, 150) +
  scale_x_continuous(breaks = k_values) +
  geom_line(colour = "#f96161") +
  labs(
    title = "Error calibration curve",
    x = "Number of Clusters (k)",
    y = "Within groups sum of squares"
  ) +
  geom_vline(xintercept = 3, colour = "#32ab9f", linetype = "longdash")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment