Last active
February 25, 2020 19:16
-
-
Save m-Py/a844fe03838a4f3de017d76f2f18d8ae to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Author: Martin Papenberg
# Year: 2019
# Fast KNN classification using RANN for the nearest-neighbour search
library("RANN")
library("data.table")
#' Fast k-nearest-neighbour classification of every row in `data`
#'
#' Each row is classified by a majority vote among its `k` nearest
#' neighbours in `data` (found with `RANN::nn2()`); the row itself never
#' votes.
#'
#' @param data Numeric data; coerced with `as.matrix()`, one row per
#'   observation.
#' @param labels The labels to predict, one per row of `data`; coerced to
#'   a factor.
#' @param k Number of neighbours that vote (capped at `nrow(data) - 1`).
#' @return A character vector with one predicted label (a factor level of
#'   `labels`) per row of `data`. Ties between categories are broken at
#'   random, which is the default behaviour of `max.col()`.
knn_rann_dt <- function(data, labels, k = 10) {
  data <- as.matrix(data)
  n <- nrow(data)
  stopifnot(
    length(labels) == n,
    n >= 2,
    is.numeric(k), length(k) == 1, !is.na(k), k >= 1
  )

  # Work on the integer codes of the factor levels.
  labels <- factor(labels)
  factor_levels <- levels(labels)
  stopifnot(length(factor_levels) >= 1)
  label_codes <- as.integer(labels)

  # Ask for one neighbour more than k: every point is returned as a
  # neighbour of itself, and that entry is discarded below.
  n_query <- min(k + 1, n)
  nn_idx <- nn2(data, k = n_query)$nn.idx

  # Mark the entry to discard in each row: the point itself where it was
  # returned (with duplicated points it is not necessarily in the first
  # column). Where it was not returned at all, drop the farthest
  # neighbour so that every row keeps the same number of voters.
  is_self <- nn_idx == seq_len(n) # seq_len(n) recycles down the columns
  self_missing <- rowSums(is_self) == 0
  is_self[self_missing, n_query] <- TRUE

  # Translate neighbour indices into category codes, keeping the shape.
  nn_categories <- label_codes[nn_idx]
  dim(nn_categories) <- dim(nn_idx)
  nn_categories[is_self] <- NA

  # Votes per category; NA entries (the discarded neighbour, or
  # neighbours whose label is missing) cast no vote.
  votes <- vapply(
    seq_along(factor_levels),
    function(i) rowSums(nn_categories == i, na.rm = TRUE),
    numeric(n)
  )
  # Guarantee an n x ncats matrix whatever shape vapply() returned.
  votes <- matrix(votes, nrow = n, ncol = length(factor_levels))

  # The column holding the row maximum is the most frequent category.
  factor_levels[max.col(votes)]
}
## Example applications ----

# Classify every iris flower from its four numeric measurements and
# report the proportion of species labels that were predicted correctly.
knns <- knn_rann_dt(iris[, 1:4], iris$Species, k = 10)
mean(knns == iris$Species)

# Timing run on randomly generated data.
# Timings reported by the author: about 0.3 sec for N = 100000 and
# about 5 sec for N = 1 million.
N <- 1e6
data <- rnorm(N)
labels <- sample(1:4, size = N, replace = TRUE)
start <- Sys.time()
knns1 <- knn_rann_dt(data, labels)
Sys.time() - start
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment