Skip to content

Instantly share code, notes, and snippets.

@mad
Created July 3, 2019 15:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mad/4230d56d73587c11dcf48eb9463ce751 to your computer and use it in GitHub Desktop.
Save mad/4230d56d73587c11dcf48eb9463ce751 to your computer and use it in GitHub Desktop.
package org.janusgraph;
import org.janusgraph.core.JanusGraph;
import org.janusgraph.core.JanusGraphFactory;
import org.janusgraph.testutil.MemoryAssess;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.RepeatedTest;
import org.junit.jupiter.api.io.TempDir;
import java.io.File;
import java.lang.ref.Reference;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
/**
 * Runs many small JanusGraph write/read batches on a fixed thread pool and afterwards dumps the
 * {@code ThreadLocal} entries of every live thread, so that entries left behind by the graph can
 * be inspected in the test output.
 *
 * <p>The thread-local dump reflects into private members of {@code java.lang.Thread} and
 * {@code java.lang.ThreadLocal$ThreadLocalMap}. On JDKs that enforce strong encapsulation this
 * only works when the JVM is started with {@code --add-opens java.base/java.lang=ALL-UNNAMED};
 * otherwise the dump is skipped and the reason is printed to {@code System.err}.
 */
class TransactionTest {

    /** Number of asynchronous tasks submitted (one after another) per test run. */
    private static final int NUM_TX = 100;

    /** Number of vertex-pair insertions performed inside each task before committing. */
    private static final int BATCH_SIZE = 5;

    /**
     * Pool whose worker threads execute the graph operations. It is shut down at the end of
     * {@link #autoTx(File)}, so this relies on a fresh test instance per invocation
     * (NOTE(review): that is JUnit's default per-method lifecycle - confirm it is not overridden).
     */
    private final ExecutorService executorService = Executors.newFixedThreadPool(10);

    /**
     * Opens a BerkeleyDB-backed graph in {@code path}, runs {@link #NUM_TX} batches on the pool
     * (waiting for each one before submitting the next), prints the value returned by
     * {@code MemoryAssess.end()}, closes the graph and finally dumps the thread-locals of all
     * live threads.
     *
     * <p>The graph is closed and the pool is shut down even when a batch fails, so a failing
     * repetition does not leak the graph or worker threads.
     *
     * @param path temporary directory supplied by JUnit; used as the storage directory
     * @throws InterruptedException if the test thread is interrupted while waiting for a batch
     * @throws ExecutionException if a batch throws on its worker thread
     */
    @RepeatedTest(5)
    // NOTE(review): @Order only takes effect when an order-annotation method orderer is
    // configured; none is declared on this class.
    @Order(2)
    void autoTx(@TempDir File path) throws InterruptedException, ExecutionException {
        System.out.println("Auto");
        MemoryAssess memoryAssess = new MemoryAssess();
        memoryAssess.start();
        try {
            JanusGraph graph = JanusGraphFactory.open("berkeleyje:" + path.getAbsolutePath());
            try {
                for (int i = 0; i < NUM_TX; i++) {
                    CompletableFuture<Void> future =
                            CompletableFuture.runAsync(() -> writeBatch(graph), executorService);
                    // Wait for each task so that batches run strictly one after another.
                    future.get();
                }
                System.out.println(memoryAssess.end());
            } finally {
                graph.close();
            }
            // Inspect the threads while the pool workers are still alive.
            checkThreadLocalsForLeaks();
        } finally {
            executorService.shutdown();
        }
    }

    /**
     * Body of one asynchronous task: {@link #BATCH_SIZE} times, adds two "V" vertices with random
     * properties joined by an "E" edge (from the second vertex to the first) and then reads back
     * all vertices; afterwards commits and closes the transaction returned by {@code graph.tx()}.
     *
     * @param graph the open graph to write to
     */
    private void writeBatch(JanusGraph graph) {
        for (int k = 0; k < BATCH_SIZE; k++) {
            graph.traversal()
                    .addV("V")
                    .property("p1", RandomStringUtils.random(10))
                    .property("p2", RandomStringUtils.random(10))
                    .as("a")
                    .addV("V")
                    .property("p1", RandomStringUtils.random(10))
                    .property("p2", RandomStringUtils.random(10))
                    .addE("E").to("a").property("p3", RandomStringUtils.random(10))
                    .iterate();
            graph.traversal().V().toList();
        }
        graph.tx().commit();
        graph.tx().close();
    }

    /**
     * Walks every live thread and prints the entries of both its regular and its inheritable
     * thread-local map (after asking each map to expunge stale entries).
     *
     * <p>This is a best-effort diagnostic: if the JDK internals cannot be reached reflectively,
     * the reason is printed to {@code System.err} and the method returns normally. The maps of
     * other threads are read without any synchronization.
     */
    private void checkThreadLocalsForLeaks() {
        Thread[] threads = getThreads();
        try {
            // Make the fields in the Thread class that store ThreadLocals accessible.
            Field threadLocalsField = Thread.class.getDeclaredField("threadLocals");
            threadLocalsField.setAccessible(true);
            Field inheritableThreadLocalsField =
                    Thread.class.getDeclaredField("inheritableThreadLocals");
            inheritableThreadLocalsField.setAccessible(true);

            // Make the underlying array of ThreadLocal.ThreadLocalMap.Entry objects accessible.
            Class<?> tlmClass = Class.forName("java.lang.ThreadLocal$ThreadLocalMap");
            Field tableField = tlmClass.getDeclaredField("table");
            tableField.setAccessible(true);
            Method expungeStaleEntriesMethod = tlmClass.getDeclaredMethod("expungeStaleEntries");
            expungeStaleEntriesMethod.setAccessible(true);

            for (Thread thread : threads) {
                if (thread == null) {
                    continue;
                }
                expungeAndCheck(threadLocalsField.get(thread), expungeStaleEntriesMethod, tableField);
                expungeAndCheck(
                        inheritableThreadLocalsField.get(thread), expungeStaleEntriesMethod, tableField);
            }
        } catch (ReflectiveOperationException | RuntimeException e) {
            // Previously swallowed silently, which made a skipped check look like "no leaks".
            System.err.println("Thread-local leak check could not be completed: " + e);
        }
    }

    /**
     * Expunges stale entries from one thread-local map and then reports its remaining entries.
     * Does nothing when {@code threadLocalMap} is {@code null} (the thread never created one).
     */
    private void expungeAndCheck(Object threadLocalMap, Method expungeStaleEntriesMethod,
            Field tableField) throws ReflectiveOperationException {
        if (threadLocalMap == null) {
            return;
        }
        expungeStaleEntriesMethod.invoke(threadLocalMap);
        checkThreadLocalMapForLeaks(threadLocalMap, tableField);
    }

    /**
     * Returns the threads enumerated from the root thread group (or from the highest group the
     * security manager lets us reach).
     *
     * @return an array holding exactly the enumerated threads, without trailing {@code null}s
     */
    private Thread[] getThreads() {
        // Start from the current thread group and climb to the root.
        ThreadGroup tg = Thread.currentThread().getThreadGroup();
        try {
            while (tg.getParent() != null) {
                tg = tg.getParent();
            }
        } catch (SecurityException ignored) {
            // Not allowed to go higher; enumerate from the highest group reached so far.
        }

        int threadCountGuess = tg.activeCount() + 50;
        Thread[] threads = new Thread[threadCountGuess];
        int threadCountActual = tg.enumerate(threads);
        // enumerate() silently drops threads that do not fit, so a completely filled array means
        // some may have been missed: retry with a bigger one.
        while (threadCountActual == threadCountGuess) {
            threadCountGuess *= 2;
            threads = new Thread[threadCountGuess];
            threadCountActual = tg.enumerate(threads);
        }
        return Arrays.copyOf(threads, threadCountActual);
    }

    /**
     * Analyzes the given thread local map object and prints one line per reported entry in the
     * form {@code [TEST, keyClass, key, valueClass, value]}. Also pass in the field that points
     * to the internal table to save re-calculating it on every call to this method.
     *
     * @param map a {@code ThreadLocal.ThreadLocalMap} instance, or {@code null} (ignored)
     * @param internalTableField the accessible {@code table} field of the map class
     * @throws IllegalAccessException if a reflective read is denied
     * @throws NoSuchFieldException if the entry class has no {@code value} field
     */
    private void checkThreadLocalMapForLeaks(Object map,
            Field internalTableField) throws IllegalAccessException,
            NoSuchFieldException {
        if (map == null) {
            return;
        }
        Object[] table = (Object[]) internalTableField.get(map);
        if (table == null) {
            return;
        }

        // Resolved on the first non-null entry and reused: all entries share one entry class.
        Field valueField = null;
        for (Object entry : table) {
            if (entry == null) {
                continue;
            }
            if (valueField == null) {
                valueField = entry.getClass().getDeclaredField("value");
                valueField.setAccessible(true);
            }

            // The entry itself is a reference to the ThreadLocal key; it may already be cleared.
            Object key = ((Reference<?>) entry).get();
            Object value = valueField.get(entry);

            boolean potentialLeak = this.equals(key) || loadedByThisOrChild(key)
                    || this.equals(value) || loadedByThisOrChild(value);
            if (!potentialLeak) {
                continue;
            }

            Object[] args = new Object[5];
            args[0] = "TEST";
            if (key != null) {
                args[1] = getPrettyClassName(key.getClass());
                args[2] = toStringOrNull(key);
            }
            if (value != null) {
                args[3] = getPrettyClassName(value.getClass());
                args[4] = toStringOrNull(value);
            }
            System.out.println(Arrays.toString(args));
        }
    }

    /**
     * Returns {@code o.toString()}, or {@code null} when that call throws, so that one misbehaving
     * object cannot abort the whole dump.
     */
    private String toStringOrNull(Object o) {
        try {
            return o.toString();
        } catch (Exception ignored) {
            // Deliberately ignored: the entry is still reported, just without its text.
            return null;
        }
    }

    /**
     * Returns the canonical name of {@code clazz}, falling back to {@link Class#getName()} for
     * classes that have no canonical name (for example anonymous or local classes).
     */
    private String getPrettyClassName(Class<?> clazz) {
        String name = clazz.getCanonicalName();
        if (name == null) {
            name = clazz.getName();
        }
        return name;
    }

    /**
     * Class-loader filter stub: always returns {@code true}, so every non-empty map entry is
     * reported by {@link #checkThreadLocalMapForLeaks(Object, Field)}.
     */
    private boolean loadedByThisOrChild(Object o) {
        return true;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment