@dwf
Created February 1, 2010 21:07
A naive, still-sort-of-inefficient k-NN implementation in idiomatic ("vectorized") MATLAB.
function bestclass = knn(train_data, labels, example, k)
%kNN-- do k-nearest neighbours classification
%
% BESTCLASS = knn(TRAIN_DATA, LABELS, EXAMPLE, K)
%
% Takes TRAIN_DATA, a D x N matrix containing N training examples of dimension
% D; LABELS, an N-vector of the (positive integer) classes assigned to each
% column of TRAIN_DATA; EXAMPLE, a D-vector consisting of the example we
% are trying to classify; and K, the number of neighbours to use in
% classifying.
%
% Returns BESTCLASS, the predicted class of the test point.
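%
% Example (hypothetical toy data: five 2-D points in two classes):
%
%   X = [0 0 1 5 6; 0 1 0 5 5];
%   y = [1 1 1 2 2];
%   knn(X, y, [4; 4], 3)   % returns 2: two of the three nearest are class 2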
%
% The K-nearest neighbour algorithm works like the one-nearest neighbour
% algorithm (which chooses the class of the training example with minimum
% Euclidean distance to the test example), but instead of using only the
% closest neighbour it takes the K closest points and returns the majority
% class among their labels. See http://en.wikipedia.org/wiki/KNN for more
% details.
%
% Compute the squared distances to each of the N training examples by
% duplicating the test vector with REPMAT, subtracting, squaring elementwise
% with .^2, and summing each column with SUM; then sort them. The square
% root is skipped because it does not change the ordering.
%
% NOTE: calling SORT with two output arguments, as done below, returns the
% sorted distances as the first output and, as the second, a vector holding
% each sorted element's original position. That is, ind(5) is the position
% in the unsorted array of the element that ends up fifth after sorting.
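%
% For instance, [val, ind] = sort([3 1 2]) gives val = [1 2 3] and
% ind = [2 3 1]: the smallest value, 1, sat at position 2 before sorting.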
%
% By David Warde-Farley -- user AT cs dot toronto dot edu (user = dwf)
% Redistributable under the terms of the 3-clause BSD license
% (see http://www.opensource.org/licenses/bsd-license.php for details)
[val, ind] = sort(sum((repmat(example(:), 1, size(train_data,2)) - train_data).^2, 1));
% Create a vector to store the number of examples observed for each class
% among the K neighbours.
counts = zeros(max(labels),1);
% Loop through the k closest neighbours
for neighbour = 1:k
    % Get the label of the current neighbour from the LABELS vector
    % and increment its count.
    counts(labels(ind(neighbour))) = counts(labels(ind(neighbour))) + 1;
end
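% An equivalent loop-free alternative (not used here; NN_LABELS is an
% illustrative name) would build the same counts with ACCUMARRAY:
%   nn_labels = labels(ind(1:k));
%   counts = accumarray(nn_labels(:), 1, [max(labels) 1]);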
% Take the class with the highest count, discarding the count itself and
% keeping its index in BESTCLASS. If several classes tie, MAX returns the
% first one, so the smallest tied label wins.
[~, bestclass] = max(counts);
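The description above calls this implementation still somewhat inefficient. A minimal sketch of a more vectorized variant follows. It assumes a MATLAB release with BSXFUN (or GNU Octave), replaces the REPMAT copy with BSXFUN and the counting loop with MODE, and uses hypothetical toy data and variable names (X, y, x, d) chosen only for illustration.

```matlab
% Hypothetical toy data: five 2-D training points stored as columns of X,
% their class labels y, a query point x, and the number of neighbours k.
X = [0 0 1 5 6;
     0 1 0 5 5];
y = [1 1 1 2 2];
x = [4; 4];
k = 3;

% Squared Euclidean distances to every column of X. BSXFUN expands x
% across the columns without building a REPMAT copy; SUM(..., 1) sums
% down each column so that 1-D data (D = 1) also works.
d = sum(bsxfun(@minus, X, x(:)).^2, 1);

% Indices of the training points ordered from nearest to farthest.
[~, ind] = sort(d);

% Majority vote among the k nearest labels. MODE returns the smallest
% label when there is a tie, matching the first-maximum rule of MAX.
bestclass = mode(y(ind(1:k)));
disp(bestclass)
```

Unlike the counting vector, MODE does not require the labels to be positive integers. In MATLAB R2016b and later, implicit expansion also allows writing X - x directly, and MINK (R2017b and later) can return just the k smallest distances without a full sort.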