-
-
Save yuj18/8a655936bb6cddf8d4e9dc87c3b23a59 to your computer and use it in GitHub Desktop.
Region-dependent KNN search
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
public class KNNParallel extends FunctionAdapter implements Declarable { | |
public static final String ID = "KNNParallel"; | |
@Override | |
public boolean isHA() { | |
return false; | |
} | |
@Override | |
public String getId() { | |
return ID; | |
} | |
@Override | |
public void init(Properties arg0) { | |
} | |
/** | |
* Find K nearest neighbors of a given query vector based on Euclidean | |
* distance. If data is partitioned over several servers, find 'local' KNNs | |
* in parallel on each partition of the data. One can obtain KNNs with | |
* respect to the entire dataset by simple post-processing of local KNN | |
* results, cf. KNN.java. | |
*/ | |
@Override | |
public void execute(FunctionContext arg0) { | |
Object[] arguments = (Object[]) arg0.getArguments(); | |
assert arguments.length == 4 : "4 arguments are expected: K, idField, valueField, queryVector"; | |
int K = Integer.parseInt(arguments[0].toString().trim()); | |
String idField = arguments[1].toString().trim(); | |
String valueField = arguments[2].toString().trim(); | |
String queryVec = arguments[3].toString().trim(); | |
// KNN entities ordered by descending distance to the query point. | |
PriorityQueue<PdxInstance> knn = new PriorityQueue<PdxInstance>(1, distComparator); | |
Cache cache = CacheFactory.getAnyInstance(); | |
Region<Object, PdxInstance> region = ((RegionFunctionContext)arg0).getDataSet(); | |
// For a partitioned region, work on local data only. | |
if (PartitionRegionHelper.isPartitionedRegion(region)) { | |
region = PartitionRegionHelper.getLocalDataForContext((RegionFunctionContext)arg0); | |
} | |
if (region.size() > 0) { | |
Iterator<PdxInstance> itr = region.values().iterator(); | |
while (itr.hasNext()) { | |
PdxInstance data = itr.next(); | |
double dist = euclidean(queryVec, ((String) data.getField(valueField)).trim()); | |
// Add a new serialized KNN entity to the queue. | |
PdxInstance knnEntity = KNNEntityToPdxInstance(cache, idField, data.getField(idField).toString(), | |
valueField, ((String) data.getField(valueField)).trim(), dist); | |
knn.offer(knnEntity); | |
// If more than K entities in the queue, remove the one | |
// with the maximum distance to the query point. | |
if (knn.size() > K) { | |
knn.poll(); | |
} | |
} | |
} | |
// Return K serialized KNN entities. Serialization and conversion of | |
// PriorityQueue to ArrayList is to enable JSON | |
// conversion, so that the function can be invoked through REST API. | |
arg0.getResultSender().lastResult(new ArrayList<>(knn)); | |
} | |
/** | |
* Euclidean distance. | |
* | |
* @param vec1 | |
* Space separated numerical string. | |
* @param vec2 | |
* Space separated numerical string. | |
* @return Euclidean distance. | |
*/ | |
public double euclidean(String vec1, String vec2) { | |
String[] v1 = vec1.split("\\s+"); | |
String[] v2 = vec2.split("\\s+"); | |
assert v1.length == v2.length : "euclidean(): Input vectors have different lengths."; | |
double res = 0; | |
for (int i = 0; i < v1.length; i++) { | |
res += Math.pow((Double.parseDouble(v1[i]) - Double.parseDouble(v2[i])), 2); | |
} | |
return Math.sqrt(res); | |
} | |
/** | |
* Comparator for serialized KNN entity comparison based on the distance to | |
* the query point. | |
*/ | |
public static Comparator<PdxInstance> distComparator = new Comparator<PdxInstance>() { | |
@Override | |
public int compare(PdxInstance a, PdxInstance b) { | |
double diff = (double) a.getField("dist") - (double) b.getField("dist"); | |
if(diff > 0) { | |
return -1; | |
} else if (diff < 0) { | |
return 1; | |
} else { | |
return 0; | |
} | |
} | |
}; | |
/** | |
* Create a PdxInstance for a KNN entity. | |
* | |
* @param gemfireCache | |
* GemFire cache. | |
* @param idField | |
* Name of the ID field of a KNN entity. | |
* @param id | |
* ID of a KNN entity. | |
* @param valueField | |
* Name of the value field of a KNN entity. | |
* @param value | |
* Vector representation of a KNN entity with the vector | |
* represented by a space separated string. | |
* @param dist | |
* Distance to the query point. | |
* @return PdxInstance of a KNN entity. | |
*/ | |
protected PdxInstance KNNEntityToPdxInstance(Cache gemfireCache, String idField, String id, String valueField, | |
String value, Double dist) { | |
PdxInstanceFactory pdxInstanceFactory = gemfireCache | |
.createPdxInstanceFactory("io.pivotal.ds.gemfire.entity.KNNEntity"); | |
pdxInstanceFactory.writeString(idField, id); | |
pdxInstanceFactory.writeString(valueField, value); | |
pdxInstanceFactory.writeDouble("dist", dist); | |
return pdxInstanceFactory.create(); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment