Skip to content

Instantly share code, notes, and snippets.

@yuj18
Last active August 12, 2016 01:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yuj18/8a655936bb6cddf8d4e9dc87c3b23a59 to your computer and use it in GitHub Desktop.
Save yuj18/8a655936bb6cddf8d4e9dc87c3b23a59 to your computer and use it in GitHub Desktop.
Region-dependent KNN search
/**
 * Region function that finds the K nearest neighbors of a query vector among
 * the entries of the region it is executed on, using Euclidean distance.
 *
 * Region values are expected to be PdxInstances exposing an ID field and a
 * value field, where the value is a whitespace separated numerical string.
 * Each result is a new PdxInstance carrying the ID, the value and the
 * computed distance (field "dist").
 */
public class KNNParallel extends FunctionAdapter implements Declarable {

    public static final String ID = "KNNParallel";

    /** Name of the field that stores the distance to the query point in a KNN entity. */
    private static final String DIST_FIELD = "dist";

    /** Number of positional arguments read by {@link #execute(FunctionContext)}. */
    private static final int EXPECTED_ARGUMENTS = 4;

    /**
     * @return always false.
     */
    @Override
    public boolean isHA() {
        return false;
    }

    /**
     * @return the identifier of this function, {@link #ID}.
     */
    @Override
    public String getId() {
        return ID;
    }

    /**
     * No initialization is required; the properties are ignored.
     */
    @Override
    public void init(Properties arg0) {
    }

    /**
     * Find K nearest neighbors of a given query vector based on Euclidean
     * distance. If data is partitioned over several servers, find 'local' KNNs
     * in parallel on each partition of the data. One can obtain KNNs with
     * respect to the entire dataset by simple post-processing of local KNN
     * results, cf. KNN.java.
     *
     * The function arguments must be an array holding, in this order: K,
     * idField, valueField, queryVector. Each argument is converted with
     * toString() and trimmed. Additional trailing arguments are ignored.
     *
     * The result sent back is a list of at most K KNN entities sorted by
     * ascending distance to the query point. If K is not positive the list is
     * empty.
     *
     * @param arg0
     *            Function context; must be a RegionFunctionContext.
     * @throws IllegalArgumentException
     *             if the function is not executed on a region, if fewer than
     *             4 arguments are supplied, if one of them is null, or if K
     *             is not an integer.
     */
    @Override
    public void execute(FunctionContext arg0) {
        if (!(arg0 instanceof RegionFunctionContext)) {
            throw new IllegalArgumentException(ID + " must be executed on a region.");
        }
        RegionFunctionContext regionContext = (RegionFunctionContext) arg0;

        // Explicit checks instead of assert: assertions are disabled by
        // default, so an assert would not protect the array accesses below.
        Object rawArguments = arg0.getArguments();
        if (!(rawArguments instanceof Object[]) || ((Object[]) rawArguments).length < EXPECTED_ARGUMENTS) {
            throw new IllegalArgumentException("4 arguments are expected: K, idField, valueField, queryVector");
        }
        Object[] arguments = (Object[]) rawArguments;
        for (int i = 0; i < EXPECTED_ARGUMENTS; i++) {
            if (arguments[i] == null) {
                throw new IllegalArgumentException("Argument " + i
                        + " is null. 4 arguments are expected: K, idField, valueField, queryVector");
            }
        }
        int k = Integer.parseInt(arguments[0].toString().trim());
        String idField = arguments[1].toString().trim();
        String valueField = arguments[2].toString().trim();
        String queryVec = arguments[3].toString().trim();

        // KNN entities ordered by descending distance to the query point, so
        // that the head of the queue is always the current worst candidate.
        PriorityQueue<PdxInstance> knn = new PriorityQueue<PdxInstance>(1, distComparator);

        // Nothing can be selected for K <= 0, so the data is not scanned.
        if (k > 0) {
            Cache cache = CacheFactory.getAnyInstance();
            Region<Object, PdxInstance> region = regionContext.getDataSet();
            // For a partitioned region, work on local data only.
            if (PartitionRegionHelper.isPartitionedRegion(region)) {
                region = PartitionRegionHelper.getLocalDataForContext(regionContext);
            }
            for (PdxInstance data : region.values()) {
                String value = ((String) data.getField(valueField)).trim();
                double dist = euclidean(queryVec, value);
                if (knn.size() >= k) {
                    // The queue is full: the candidate is only kept if it is
                    // strictly closer than the current worst one. Double.compare
                    // is used so that a NaN distance counts as the farthest.
                    double worst = (Double) knn.peek().getField(DIST_FIELD);
                    if (Double.compare(dist, worst) >= 0) {
                        continue;
                    }
                    knn.poll();
                }
                // Serialize only the entities that actually enter the queue.
                knn.offer(KNNEntityToPdxInstance(cache, idField, data.getField(idField).toString(), valueField,
                        value, dist));
            }
        }

        // Return K serialized KNN entities. Serialization and conversion of
        // PriorityQueue to ArrayList is to enable JSON
        // conversion, so that the function can be invoked through REST API.
        ArrayList<PdxInstance> result = new ArrayList<PdxInstance>(knn);
        // Copying a PriorityQueue yields no particular order; sort so that
        // the nearest neighbor comes first.
        java.util.Collections.sort(result, java.util.Collections.reverseOrder(distComparator));
        arg0.getResultSender().lastResult(result);
    }

    /**
     * Euclidean distance.
     *
     * @param vec1
     *            Space separated numerical string.
     * @param vec2
     *            Space separated numerical string.
     * @return Euclidean distance.
     * @throws IllegalArgumentException
     *             if the two vectors do not have the same number of
     *             components (including NumberFormatException when a
     *             component is not a number).
     */
    public double euclidean(String vec1, String vec2) {
        // Trim first: a leading separator would otherwise produce an empty
        // first component.
        String[] v1 = vec1.trim().split("\\s+");
        String[] v2 = vec2.trim().split("\\s+");
        if (v1.length != v2.length) {
            throw new IllegalArgumentException("euclidean(): Input vectors have different lengths.");
        }
        double res = 0;
        for (int i = 0; i < v1.length; i++) {
            double delta = Double.parseDouble(v1[i]) - Double.parseDouble(v2[i]);
            res += delta * delta;
        }
        return Math.sqrt(res);
    }

    /**
     * Comparator for serialized KNN entity comparison based on the distance to
     * the query point. Entities with a larger distance are ordered first, so a
     * PriorityQueue built with this comparator has the farthest entity at its
     * head. Double.compare gives a consistent total order even when a distance
     * is NaN (NaN is treated as the largest distance).
     */
    public static Comparator<PdxInstance> distComparator = new Comparator<PdxInstance>() {
        @Override
        public int compare(PdxInstance a, PdxInstance b) {
            return Double.compare((Double) b.getField(DIST_FIELD), (Double) a.getField(DIST_FIELD));
        }
    };

    /**
     * Create a PdxInstance for a KNN entity.
     *
     * @param gemfireCache
     *            GemFire cache.
     * @param idField
     *            Name of the ID field of a KNN entity.
     * @param id
     *            ID of a KNN entity.
     * @param valueField
     *            Name of the value field of a KNN entity.
     * @param value
     *            Vector representation of a KNN entity with the vector
     *            represented by a space separated string.
     * @param dist
     *            Distance to the query point.
     * @return PdxInstance of a KNN entity.
     */
    protected PdxInstance KNNEntityToPdxInstance(Cache gemfireCache, String idField, String id, String valueField,
            String value, Double dist) {
        PdxInstanceFactory pdxInstanceFactory = gemfireCache
                .createPdxInstanceFactory("io.pivotal.ds.gemfire.entity.KNNEntity");
        pdxInstanceFactory.writeString(idField, id);
        pdxInstanceFactory.writeString(valueField, value);
        pdxInstanceFactory.writeDouble(DIST_FIELD, dist);
        return pdxInstanceFactory.create();
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment