Skip to content

Instantly share code, notes, and snippets.

@m-Py
Last active February 25, 2020 19:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save m-Py/a844fe03838a4f3de017d76f2f18d8ae to your computer and use it in GitHub Desktop.
Save m-Py/a844fe03838a4f3de017d76f2f18d8ae to your computer and use it in GitHub Desktop.
# Author: Martin Papenberg
# Year: 2019
# Fast k-nearest-neighbour (KNN) classifier that uses RANN for the nearest-neighbour search
library("RANN")
library("data.table")
#' Fast k-nearest-neighbour (KNN) classification
#'
#' Classifies every row of `data` by a majority vote among its `k` nearest
#' neighbours, excluding the row itself, so the result is a leave-one-out
#' prediction for each observation. The neighbour search uses `nn2()` from
#' the RANN package.
#'
#' @param data The numeric data: a matrix, or anything `as.matrix()` turns
#'   into one (e.g. a numeric data frame or a vector), one observation per row.
#' @param labels The known labels, one per row of `data`; coerced with
#'   `factor()`. Neighbours whose label is `NA` do not vote.
#' @param k The number of neighbours (other than the point itself) that vote.
#'   Values larger than `nrow(data) - 1` are capped at that maximum.
#' @param ties_method How to choose between equally frequent labels; passed
#'   to `max.col()`. The default `"random"` uses R's random number
#'   generator, so call `set.seed()` first for reproducible predictions.
#' @return A character vector of predicted labels (levels of
#'   `factor(labels)`), one per row of `data`.
knn_rann_dt <- function(data, labels, k = 10,
                        ties_method = c("random", "first", "last")) {
  ties_method <- match.arg(ties_method)
  data <- as.matrix(data)
  n <- nrow(data)
  stopifnot(n >= 2, length(labels) == n, length(k) == 1, k >= 1)
  labels <- factor(labels)
  label_levels <- levels(labels)
  if (length(label_levels) == 0) {
    stop("`labels` must contain at least one non-missing value", call. = FALSE)
  }

  # Ask for one extra neighbour so that k remain after removing the point
  # itself (nn2() cannot return more neighbours than there are rows).
  n_query <- min(k + 1, n)
  idx <- nn2(data, k = n_query)$nn.idx
  # Mark exactly one entry per row for removal: the point itself wherever it
  # was returned (with duplicate points it need not be in the first column),
  # otherwise the last returned neighbour.
  drop_mask <- idx == seq_len(n)
  drop_mask[rowSums(drop_mask) == 0, n_query] <- TRUE
  # Index through the transposed matrices so the kept neighbour indices stay
  # grouped by row and keep their original order; always a matrix, even
  # when only one neighbour remains per row.
  nn_idx <- matrix(t(idx)[!t(drop_mask)], nrow = n, byrow = TRUE)

  # Integer label code of every neighbour, one row per point.
  nn_codes <- matrix(as.integer(labels)[nn_idx], nrow = n)
  # votes[i, j]: number of neighbours of point i whose label is level j.
  votes <- vapply(
    seq_along(label_levels),
    function(j) rowSums(nn_codes == j, na.rm = TRUE),
    numeric(n)
  )
  label_levels[max.col(votes, ties.method = ties_method)]
}
## Some example applications ----

# knn_rann_dt() may break ties at random, and the simulated data below is
# random too, so fix the seed to make every number reproducible.
set.seed(2019)

# Proportion of iris flowers whose species is predicted correctly.
knns <- knn_rann_dt(iris[, 1:4], iris[, 5], k = 10)
mean(knns == iris$Species)

# Randomly generated data for measuring running time. The author reported
# roughly 0.3 sec for N = 100,000 and about 5 sec for N = 1 million;
# actual timings depend on the machine.
N <- 1e6
data <- rnorm(N)
labels <- sample(1:4, size = N, replace = TRUE)
system.time(knns1 <- knn_rann_dt(data, labels))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment