Skip to content

Instantly share code, notes, and snippets.

@maxdemarzi
Created September 12, 2017 17:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save maxdemarzi/ee3e3be8fa10f4e25a8ba9df31a629ac to your computer and use it in GitHub Desktop.
Save maxdemarzi/ee3e3be8fa10f4e25a8ba9df31a629ac to your computer and use it in GitHub Desktop.

CALL com.maxdemarzi.similarity(0.90, 100)

package com.maxdemarzi;
import com.maxdemarzi.results.StringResult;
import org.neo4j.kernel.internal.GraphDatabaseAPI;
import org.neo4j.logging.Log;
import org.neo4j.procedure.*;
import java.util.stream.Stream;
/**
 * Neo4j stored procedure entry point that computes similarity relationships.
 *
 * <p>The work itself is delegated to a {@link SimilarityRunnable}, which is run on a
 * separate thread; the procedure blocks until that thread has finished.
 */
public class Similarity {

    /** Database handle, injected through {@code @Context}. */
    @Context
    public GraphDatabaseAPI db;

    /** Logger, injected through {@code @Context}. */
    @Context
    public Log log;

    /**
     * Runs the similarity calculation and waits for it to complete.
     *
     * @param min   minimum similarity score a pair must reach to be linked
     * @param limit maximum number of similar accounts to link per account
     * @return a single-row stream carrying a completion message
     * @throws IllegalArgumentException if either argument is {@code null} or {@code limit} is negative
     * @throws RuntimeException         if the worker thread terminated with an exception
     * @throws InterruptedException     if the calling thread is interrupted while waiting
     */
    @Description("com.maxdemarzi.similarity() ")
    @Procedure(name = "com.maxdemarzi.similarity", mode = Mode.WRITE)
    public Stream<StringResult> Similarity(@Name("minimum") Double min, @Name("limit") Number limit) throws InterruptedException {
        if (min == null || limit == null) {
            throw new IllegalArgumentException("minimum and limit must not be null");
        }
        if (limit.intValue() < 0) {
            throw new IllegalArgumentException("limit must not be negative: " + limit);
        }

        // Single-slot holder for whatever exception kills the worker. It is written by the
        // worker's uncaught-exception handler and read only after join(), which guarantees
        // the write is visible here.
        final Throwable[] failure = new Throwable[1];

        // NOTE(review): presumably run on its own thread so the runnable's batched
        // transactions are independent of the calling procedure's transaction - confirm.
        Thread worker = new Thread(new SimilarityRunnable(min, limit.intValue(), db, log));
        worker.setUncaughtExceptionHandler((thread, error) -> failure[0] = error);
        worker.start();
        worker.join();

        // Without this check a crashed worker would still be reported as a success.
        if (failure[0] != null) {
            log.error("Similarity calculation failed", failure[0]);
            throw new RuntimeException("Similarity calculation failed", failure[0]);
        }
        return Stream.of(new StringResult("Similarities were calculated."));
    }
}
package com.maxdemarzi;
import com.maxdemarzi.schema.Labels;
import com.maxdemarzi.schema.RelationshipTypes;
import org.neo4j.graphdb.*;
import org.neo4j.helpers.collection.Pair;
import org.neo4j.kernel.internal.GraphDatabaseAPI;
import org.neo4j.logging.Log;
import java.util.*;
import java.util.concurrent.TimeUnit;
import static java.util.Collections.reverseOrder;
public class SimilarityRunnable implements Runnable {
private static final int TRANSACTION_LIMIT = 10;
private static GraphDatabaseAPI db;
private Double min;
private Integer limit;
private static Log log;
public SimilarityRunnable (Double min, Integer limit, GraphDatabaseAPI db, Log log) {
this.min = min;
this.limit = limit;
this.db = db;
this.log = log;
}
@Override
public void run() {
long start = System.nanoTime();
// Get all the Customer Accounts that have been divested
ArrayList<Node> divestedAccounts = new ArrayList<>();
try (Transaction tx = db.beginTx()) {
ResourceIterator<Node> iterator = db.findNodes(Labels.divested);
while (iterator.hasNext()) {
divestedAccounts.add(iterator.next());
}
tx.success();
}
// For each divested account find similar accounts
Transaction tx = db.beginTx();
int count = 0;
try {
for (Node account : divestedAccounts) {
count++;
Map<Node, List<Double>> mine = new HashMap<>();
Map<Node, List<Double>> theirs = new HashMap<>();
for (Relationship r : account.getRelationships(Direction.OUTGOING, RelationshipTypes.TAGGED)) {
Double weight = (Double)r.getProperty("weight");
Node vector = r.getEndNode();
for (Relationship r2 : vector.getRelationships(Direction.INCOMING, RelationshipTypes.TAGGED)) {
Node account2 = r2.getStartNode();
if (!account.equals(account2)) {
addVectorNodes(mine, account2, weight);
addVectorNodes(theirs, account2, (Double)r2.getProperty("weight"));
}
}
}
ArrayList<Pair<Node, Double>> top = new ArrayList<>();
for (Map.Entry<Node, List<Double>> entry : mine.entrySet()) {
double score = calculateSimilarity(entry.getValue(), theirs.get(entry.getKey()));
if (score >= min) {
top.add(Pair.of(entry.getKey(), score));
}
}
top.sort(Comparator.comparing(m -> (Double) m.other(), reverseOrder()));
for (Pair<Node, Double> calculation : top.subList(0, Math.min(top.size(), limit))){
Relationship similar = account.createRelationshipTo(calculation.first(), RelationshipTypes.SIMILAR);
similar.setProperty("similarity", calculation.other());
}
if (count % TRANSACTION_LIMIT == 0) {
tx.success();
tx.close();
tx = db.beginTx();
log.info("Committing similarity work after " + count + " in " + TimeUnit.NANOSECONDS.toSeconds(System.nanoTime() - start) + " seconds since starting.");
}
}
tx.success();
} finally {
tx.close();
}
long timeTaken = TimeUnit.NANOSECONDS.toSeconds(System.nanoTime() - start);
log.info("Similarity calculated in " + timeTaken + " Seconds");
}
private void addVectorNodes(Map<Node, List<Double>> multimap, Node key, Double value) {
List<Double> list = multimap.computeIfAbsent(key, k -> new ArrayList<>());
list.add(value);
}
private double calculateSimilarity(List<Double> vector1, List<Double> vector2) {
double dotProduct = 0d;
double xLength = 0d;
double yLength = 0d;
for (int i = 0; i < vector1.size(); i++) {
dotProduct += vector1.get(i) * vector2.get(i);
xLength += vector1.get(i) * vector1.get(i);
yLength += vector2.get(i) * vector2.get(i);
}
xLength = Math.sqrt(xLength);
yLength = Math.sqrt(yLength);
return dotProduct / (xLength * yLength);
}
}
package com.maxdemarzi.results;
/**
 * Result row holding a single string in its public {@code value} field.
 *
 * @author mh
 * @since 26.02.16
 */
public class StringResult {

    /** Shared instance whose {@code value} is {@code null}. */
    public static final StringResult EMPTY = new StringResult(null);

    /** The wrapped string; may be {@code null}. */
    public final String value;

    /**
     * Wraps the given string.
     *
     * @param value the string to expose; {@code null} is accepted
     */
    public StringResult(String value) {
        this.value = value;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment