Skip to content

Instantly share code, notes, and snippets.

@yuj18

yuj18/KNN.java Secret

Last active August 12, 2016 01:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yuj18/972685ba66df0b8725089bfdaf2ea6c7 to your computer and use it in GitHub Desktop.
Save yuj18/972685ba66df0b8725089bfdaf2ea6c7 to your computer and use it in GitHub Desktop.
KNN aggregate
/**
 * Function that returns the {@code K} entries of a region closest to a query
 * vector.
 *
 * <p>The work is split in two phases. First, this function fans out to the
 * {@code KNNParallel} function over the target region, which returns a list of
 * local candidates per participating member. Second, this function merges
 * those candidate lists into a single global top-{@code K}.
 *
 * <p>Each candidate is a {@link PdxInstance} carrying a {@code double} field
 * named {@code dist}: its distance to the query point. That distance is
 * computed by {@code KNNParallel}, not here; it is Euclidean per the original
 * design (TODO confirm against {@code KNNParallel}).
 */
public class KNN extends FunctionAdapter implements Declarable {

    /** Identifier returned by {@link #getId()}. */
    public static final String ID = "KNN";

    /** Number of arguments {@link #execute} expects. */
    private static final int EXPECTED_ARG_COUNT = 5;

    /**
     * Orders candidates by <em>descending</em> {@code dist}. A
     * {@link PriorityQueue} built with this comparator is therefore a max-heap
     * whose head is the farthest candidate. NaN distances are ordered as
     * defined by {@link Double#compare}.
     */
    public static final Comparator<PdxInstance> distComparator = new Comparator<PdxInstance>() {
        @Override
        public int compare(PdxInstance a, PdxInstance b) {
            // Operands deliberately swapped: the larger distance sorts first.
            return Double.compare((double) b.getField("dist"), (double) a.getField("dist"));
        }
    };

    /** Returns {@code false}: this function is not marked highly available. */
    @Override
    public boolean isHA() {
        return false;
    }

    /** Returns {@link #ID}. */
    @Override
    public String getId() {
        return ID;
    }

    /** No initialization is needed; the supplied properties are ignored. */
    @Override
    public void init(Properties props) {
    }

    /**
     * Finds the {@code K} nearest neighbours of a query vector.
     *
     * <p>The arguments must be an {@code Object[]} of length 5:
     * <ol>
     *   <li>{@code regionName}: name of the region to search (trimmed);</li>
     *   <li>{@code K}: number of neighbours to return, parsed from its trimmed
     *       string form;</li>
     *   <li>{@code idField}, {@code valueField}, {@code queryVector}: forwarded
     *       unchanged, together with {@code K}, to {@code KNNParallel}.</li>
     * </ol>
     *
     * <p>Sends a single result: an {@link ArrayList} of at most {@code K}
     * candidates, ordered nearest first. A non-positive {@code K} yields an
     * empty list.
     *
     * @param context supplies the arguments and the result sender
     * @throws IllegalArgumentException if the arguments are not an
     *     {@code Object[]} of length 5, if {@code K} is not an integer, or if
     *     the region does not exist
     */
    @Override
    public void execute(FunctionContext context) {
        Object rawArgs = context.getArguments();
        // Validate explicitly. An `assert` is skipped unless the JVM runs with -ea,
        // and Arrays.copyOfRange would then silently pad a short array with nulls.
        if (!(rawArgs instanceof Object[]) || ((Object[]) rawArgs).length != EXPECTED_ARG_COUNT) {
            throw new IllegalArgumentException(
                    "5 arguments are expected: regionName, K, idField, valueField, queryVector");
        }
        Object[] arguments = (Object[]) rawArgs;
        String regionName = arguments[0].toString().trim();
        int k = Integer.parseInt(arguments[1].toString().trim());

        Cache cache = CacheFactory.getAnyInstance();
        Region<Object, PdxInstance> region = cache.getRegion(regionName);
        if (region == null) {
            throw new IllegalArgumentException("Region not found: " + regionName);
        }

        // Phase 1: find 'local' KNNs on each partition of the data in parallel.
        // If the data is not partitioned, this yields KNNs for the whole dataset.
        Execution exe = FunctionService.onRegion(region)
                .withArgs(Arrays.copyOfRange(arguments, 1, EXPECTED_ARG_COUNT));
        ResultCollector<?, ?> resultCollector = exe.execute("KNNParallel");
        // One candidate list per result sent by KNNParallel.
        @SuppressWarnings("unchecked")
        Iterable<? extends Iterable<PdxInstance>> partials =
                (Iterable<? extends Iterable<PdxInstance>>) resultCollector.getResult();

        // Phase 2: merge the local results with a bounded max-heap. Whenever it
        // holds more than k candidates, evict the farthest one.
        PriorityQueue<PdxInstance> knn = new PriorityQueue<PdxInstance>(1, distComparator);
        for (Iterable<PdxInstance> partial : partials) {
            for (PdxInstance candidate : partial) {
                knn.offer(candidate);
                if (knn.size() > k) {
                    knn.poll();
                }
            }
        }

        // Polling the max-heap yields the farthest candidate first, so fill the
        // array back to front to return the neighbours nearest first.
        PdxInstance[] nearestFirst = new PdxInstance[knn.size()];
        for (int i = nearestFirst.length - 1; i >= 0; i--) {
            nearestFirst[i] = knn.poll();
        }
        context.getResultSender().lastResult(new ArrayList<PdxInstance>(Arrays.asList(nearestFirst)));
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment