-
-
Save yuj18/972685ba66df0b8725089bfdaf2ea6c7 to your computer and use it in GitHub Desktop.
KNN aggregate
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
public class KNN extends FunctionAdapter implements Declarable { | |
public static final String ID = "KNN"; | |
@Override | |
public boolean isHA() { | |
return false; | |
} | |
@Override | |
public String getId() { | |
return ID; | |
} | |
@Override | |
public void init(Properties arg0) { | |
} | |
/** | |
* Find K nearest neighbors of a given query vector based on Euclidean | |
* distance. | |
*/ | |
@Override | |
public void execute(FunctionContext arg0) { | |
Object[] arguments = (Object[]) arg0.getArguments(); | |
assert arguments.length == 5 : "5 arguments are expected: regionName, K, idField, valueField, queryVector"; | |
String regionName = arguments[0].toString().trim(); | |
int K = Integer.parseInt(arguments[1].toString().trim()); | |
Cache cache = CacheFactory.getAnyInstance(); | |
Region<Object, PdxInstance> region = cache.getRegion(regionName); | |
Execution exe = FunctionService.onRegion(region).withArgs(Arrays.copyOfRange(arguments, 1, 5)); | |
// Find 'local' KNNs on each partition of the data in parallel. If data | |
// is not partitioned, return KNNs with respect to the entire dataset. | |
ResultCollector<?, ?> resultCollector = exe.execute("KNNParallel"); | |
@SuppressWarnings("unchecked") | |
ArrayList<ArrayList<PdxInstance>> result = (ArrayList<ArrayList<PdxInstance>>) resultCollector.getResult(); | |
// Consolidate local KNN results. | |
PriorityQueue<PdxInstance> knn = new PriorityQueue<PdxInstance>(1, distComparator); | |
for (int i = 0; i < result.size(); i++) { | |
for (int j = 0; j < result.get(i).size(); j++) { | |
knn.offer(result.get(i).get(j)); | |
// If more than K entities in the queue, remove the one | |
// with the maximum distance to the query point. | |
if (knn.size() > K) { | |
knn.poll(); | |
} | |
} | |
} | |
arg0.getResultSender().lastResult(new ArrayList<>(knn)); | |
} | |
/** | |
* Comparator for serialized KNN entity comparison based on the distance to | |
* the query point. | |
*/ | |
public static Comparator<PdxInstance> distComparator = new Comparator<PdxInstance>() { | |
@Override | |
public int compare(PdxInstance a, PdxInstance b) { | |
double diff = (double) a.getField("dist") - (double) b.getField("dist"); | |
if(diff > 0) { | |
return -1; | |
} else if (diff < 0) { | |
return 1; | |
} else { | |
return 0; | |
} | |
} | |
}; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment