Last active September 12, 2016 16:46
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseAccumulation;
import org.nd4j.linalg.api.ops.impl.accum.distances.EuclideanDistance;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
/**
 * Requires RC1<br/>
 * This is a sample implementation of K-Nearest Neighbor (KNN). KNN doesn't create a latent/hidden model, nor does it train. It simply uses its dataset as a lookup table
 * and grabs the k closest data points/records. From there it looks at the labels of those k nearest neighbors and decides which label to predict for the incoming
 * data. In this implementation, it just takes the class with the largest representation among the neighbors.
 * <br/><br/><br/>
 * Note: Would like to implement multiple neighborhood aggregation functions, made pluggable in the same way as the distanceMeasure.
 * @author Mike Depies
 */
public class KNN {
    DataSet data;
    BaseAccumulation distanceMeasure;
    int numberOfNeighbors;

    /**
     * Instantiates KNN with the default setup: K=5, distanceMeasure = Euclidean.
     */
    public KNN() {
        this(5);
    }

    /**
     * Instantiates KNN with Euclidean as its distanceMeasure and a user-set neighborhood size.
     * @param numberOfNeighbors - The number of neighbors to look at, also known as K.
     */
    public KNN(int numberOfNeighbors) {
        this(numberOfNeighbors, new EuclideanDistance());
    }

    /**
     * Instantiates KNN with the specified distanceMeasure and neighborhood size.
     * @param numberOfNeighbors - The number of neighbors to look at, also known as K.
     * @param distanceMeasure - The kind of distance measurement. (Euclidean, Manhattan, ...)
     */
    public KNN(int numberOfNeighbors, BaseAccumulation distanceMeasure) {
        this.distanceMeasure = distanceMeasure;
        this.numberOfNeighbors = numberOfNeighbors;
    }

    /**
     * KNN doesn't learn anything intrinsic about its data. Instead, on each predict call, every record in the set is examined.
     * This model employs lazy learning (instance-based learning).
     * @param data - Expected to be as much relevant data as possible.
     */
    public void fit(DataSet data) {
        this.data = data;
    }

    /**
     * Takes in n rows of input records and returns n predictions.<br/><br/>
     * NOTE: There may be room for optimization in handling the multi-record input.
     * @param input - The matrix (or single vector) of input to test.
     * @return a corresponding array of label index predictions.
     */
    public int[] predict(INDArray input) {
        INDArray features = data.getFeatures();
        INDArray labels = data.getLabels();
        int numberOfInputs = input.rows();
        int[] predictions = new int[numberOfInputs];
        boolean includeMeasure = false;
        //Hardcoded offset: 1 when findKNeighbors keeps the distance in column 0, otherwise 0.
        int offset = (includeMeasure) ? 1 : 0;
        INDArray distanceNDArray = measureDistance(input, features, distanceMeasure);
        //This is a 3d array: [input example, neighbor, distance and labels]
        INDArray nearestNeighbors = findKNeighbors(numberOfNeighbors, includeMeasure, distanceNDArray, labels);
        //Collect the label column indexes, skipping the distance column when it is present.
        //Dimension 2 holds the distanceMeasure and the labels.
        //offset is the number of distance columns to skip.
        int[] indexes = new int[nearestNeighbors.size(2) - offset];
        for (int i = 0; i < indexes.length; i++) {
            indexes[i] = i + offset;
        }
        for (int i = 0; i < numberOfInputs; i++) {
            //Sum the label columns over the k nearest neighbors of the ith input and take the index with the largest count.
            INDArray measureAndLabelMatrix = nearestNeighbors.tensorAlongDimension(i, 2, 1);
            INDArray labelMatrix = measureAndLabelMatrix.getColumns(indexes);
            int predictedLabelIndex = Nd4j.getBlasWrapper().iamax(Nd4j.sum(labelMatrix, 0));
            predictions[i] = predictedLabelIndex;
        }
        return predictions;
    }

    /**
     * Measures the distance between each input row and each row of a feature matrix.
     * @param feature - The input row(s) to measure against the matrix.
     * @param featureMatrix - The matrix of data, likely from a dataset.
     * @param distanceMeasure - The distance operation to apply to each pair of rows.
     * @return a matrix of distances whose rows line up with the rows of the featureMatrix and whose columns line up with the input rows.
     */
    private INDArray measureDistance(INDArray feature, INDArray featureMatrix, BaseAccumulation distanceMeasure) {
        int numberOfRows = featureMatrix.rows();
        int numberOfInput = feature.rows();
        INDArray distances = Nd4j.zeros(numberOfRows, numberOfInput);
        for (int inputIndex = 0; inputIndex < numberOfInput; inputIndex++) {
            for (int rowIndex = 0; rowIndex < numberOfRows; rowIndex++) {
                //The statements that load this pair of rows into distanceMeasure and execute it are missing from this listing.
                distances.put(rowIndex, inputIndex, distanceMeasure.currentResult());
            }
        }
        return distances;
    }

    /**
     * Pairs up the measured distances for each observation with the labels, sorts them on the distance, and grabs the k records with the smallest distances.
     * @param k - The size of the neighborhood.
     * @param returnDistanceMeasure - Whether to include the distance in the neighborhood information.
     * @param distanceNDArray - The measured distances, one column per input. Each column needs as many entries as labelMatrix has rows.
     * @param labelMatrix - The label matrix of the dataset at hand.
     * @return the neighborhood array, comprised of distance and labels.
     * Dims are [0 = the input example, 1 = instances in the neighborhood running 0 to k-1, 2 = distance and labels (index 0 is the distance when included, the remaining are labels)]
     */
    private INDArray findKNeighbors(int k, boolean returnDistanceMeasure, INDArray distanceNDArray, INDArray labelMatrix) {
        //Horizontally merge the distance vector with the label matrix.
        int labelLength = labelMatrix.columns();
        int distLength = (returnDistanceMeasure) ? 1 : 0;
        int infoSize = distLength + labelLength;
        INDArray mergedSet;
        INDArray kNeighbors3d = Nd4j.create(distanceNDArray.columns(), k, infoSize);
        //Create the column indexes that decide whether the distance measure is included or not.
        int[] indexes = new int[infoSize];
        for (int i = 0; i < indexes.length; i++) {
            indexes[i] = i + (1 - distLength);
        }
        for (int z = 0; z < distanceNDArray.columns(); z++) {
            mergedSet = Nd4j.hstack(distanceNDArray.getColumn(z), labelMatrix);
            //Sort ascending on column 0 (the distance), then keep only the selected columns.
            INDArray sortedSet = Nd4j.sortRows(mergedSet, 0, true).getColumns(indexes);
            INDArray kNeighbors = kNeighbors3d.tensorAlongDimension(z, 2, 1);
            for (int i = 0; i < k; i++) {
                kNeighbors.putRow(i, sortedSet.getRow(i));
            }
        }
        return kNeighbors3d;
    }
}
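
The steps described in the Javadoc above (measure distances, keep the k closest records, pick the most represented class) can be sketched without ND4J. This is a minimal illustration in plain Java, not the gist's implementation: the class name `KnnSketch`, its method names, and the use of `double[][]` features with one-hot `int[][]` labels are all assumptions made for the example.

```java
import java.util.Arrays;
import java.util.Comparator;

/** Hypothetical plain-Java sketch of the KNN steps; not the gist's ND4J code. */
class KnnSketch {

    /** Euclidean distance between two equal-length feature vectors. */
    static double euclidean(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    /**
     * Predicts a label index for one query row.
     * labels is assumed one-hot: labels[row][c] == 1 when the row belongs to class c.
     */
    static int predict(double[][] features, int[][] labels, double[] query, int k) {
        // Step 1: distance from the query to every stored row.
        double[] distances = new double[features.length];
        Integer[] order = new Integer[features.length];
        for (int row = 0; row < features.length; row++) {
            distances[row] = euclidean(query, features[row]);
            order[row] = row;
        }
        // Step 2: sort row indexes by ascending distance.
        Arrays.sort(order, Comparator.comparingDouble(row -> distances[row]));
        // Step 3: sum the one-hot labels of the k closest rows.
        int[] votes = new int[labels[0].length];
        for (int i = 0; i < Math.min(k, order.length); i++) {
            for (int c = 0; c < votes.length; c++) {
                votes[c] += labels[order[i]][c];
            }
        }
        // Step 4: the class with the most votes wins (ties go to the lowest index).
        int best = 0;
        for (int c = 1; c < votes.length; c++) {
            if (votes[c] > votes[best]) {
                best = c;
            }
        }
        return best;
    }
}
```

For example, with three class-0 points at (0, 0), (0, 1) and (1, 0) and two class-1 points at (5, 5) and (5, 6), `KnnSketch.predict(features, labels, new double[]{0.2, 0.1}, 3)` returns 0, because all three nearest neighbors belong to class 0.
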
I get the following error:

Description Resource Path Location Type
The method setCurrentResult(int) is undefined for the type BaseAccumulation /UnDl4j/src/main/java line 111 Java Problem

I think I have everything updated.
