Skip to content

Instantly share code, notes, and snippets.

@mikeosterlie
Last active December 24, 2015 07:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mikeosterlie/0d41249e46c9e7205ac9 to your computer and use it in GitHub Desktop.
Save mikeosterlie/0d41249e46c9e7205ac9 to your computer and use it in GitHub Desktop.
Test case for a transactional, multithreaded commit problem.
package fpx.orientdb.commit;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;
import junit.framework.TestCase;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import com.orientechnologies.orient.client.remote.OServerAdmin;
import com.orientechnologies.orient.core.config.OGlobalConfiguration;
import com.orientechnologies.orient.core.db.graph.OGraphDatabase;
import com.orientechnologies.orient.core.db.graph.OGraphDatabasePool;
import com.orientechnologies.orient.core.db.record.OIdentifiable;
import com.orientechnologies.orient.core.id.ORID;
import com.orientechnologies.orient.core.id.ORecordId;
import com.orientechnologies.orient.core.iterator.ORecordIteratorClass;
import com.orientechnologies.orient.core.metadata.schema.OClass;
import com.orientechnologies.orient.core.metadata.schema.OType;
import com.orientechnologies.orient.core.record.impl.ODocument;
public class OrientCommitTest extends TestCase {
// Connection settings for the database under test.
public static final String DB_URL = "memory:avltreetest";
public static final String DB_USER = "admin";
public static final String DB_PASSWORD = "admin";
// Passed to OServerAdmin as the storage argument when DB_URL is a remote URL.
public static final String DB_STORAGE = "plocal";
public static final String DB_TYPE = "graph";
// Vertex class and property names used by the test schema.
private static final String TEST_CLASS = "ORIENT_COMMIT_TEST";
private static final String THREAD_ID = "ThreadId";
private static final String ID = "IdField";
private OGraphDatabasePool dbPool;
// Both fields below are written by worker threads through setFailureMessage() and read by the
// thread running the test, so they are volatile to make the update visible across threads.
private volatile String failureMessage = "";
private volatile boolean isValidData;
private TestExecutor[] threads;
// Default parameters handed to executeTest() by the test methods.
final int threadCount = 5;
final int maxSleepTime = 100;
final int maxOpCount = 6;
final int initialCacheSize = 10;
// Source of the custom ids stored in the ID property; shared by all workers.
final AtomicInteger idGenerator = new AtomicInteger(1);
// Never reassigned, so it is declared final.
private static final Random random = new Random();
/**
 * Creates the database and its connection pool, then builds the test schema.
 */
public void setUp() {
    this.dbPool = bootstrap();
    final OGraphDatabase db = this.dbPool.acquire();
    try {
        buildSchemaAndSeed(db);
    } finally {
        // Release the connection; previously it stayed acquired for the whole test.
        db.close();
    }
    this.isValidData = true;
}
/**
 * Closes the database pool that setUp() created.
 */
public void tearDown() {
this.dbPool.close();
}
/**
 * Runs the multithreaded scenario with transactions enabled until a worker reports a failure.
 * Output is redirected to log/CommitTestTransactional.txt for the duration of the test.
 */
public void testWithTransaction() {
    final PrintStream originalOut = System.out;
    PrintStream logOut = null;
    try {
        final File logFile = new File("log/CommitTestTransactional.txt");
        // PrintStream does not create missing directories
        logFile.getParentFile().mkdirs();
        logOut = new PrintStream(logFile);
        System.setOut(logOut);
    } catch (FileNotFoundException e) {
        // keep logging to the console, but say why the file is missing
        System.err.println("Unable to open log file, logging to the console instead: " + e.getMessage());
    }
    try {
        // runtime of 0: run until it fails, with the transaction flag set to true
        executeTest(this.threadCount, true, this.maxSleepTime, this.maxOpCount, this.initialCacheSize, 0);
    } finally {
        // restore the console before closing the file so late output is not lost
        System.setOut(originalOut);
        if (logOut != null) {
            logOut.close();
        }
    }
}
/**
 * Runs the scenario with a single worker and transactions enabled for one minute.
 * Output is redirected to log/CommitTestTransactionalSingleThread.txt for the duration of the test.
 */
public void testSingleThreadWithTransaction() {
    final PrintStream originalOut = System.out;
    PrintStream logOut = null;
    try {
        final File logFile = new File("log/CommitTestTransactionalSingleThread.txt");
        // PrintStream does not create missing directories
        logFile.getParentFile().mkdirs();
        logOut = new PrintStream(logFile);
        System.setOut(logOut);
    } catch (FileNotFoundException e) {
        // keep logging to the console, but say why the file is missing
        System.err.println("Unable to open log file, logging to the console instead: " + e.getMessage());
    }
    try {
        // set to run for 1 minute with the transaction flag set to true
        executeTest(1, true, this.maxSleepTime, this.maxOpCount, this.initialCacheSize, 1);
    } finally {
        // restore the console before closing the file so late output is not lost
        System.setOut(originalOut);
        if (logOut != null) {
            logOut.close();
        }
    }
}
/**
 * Runs the multithreaded scenario without transactions for one minute.
 * Output is redirected to log/CommitTestNonTransactional.txt for the duration of the test.
 */
public void testWithoutTransaction() {
    final PrintStream originalOut = System.out;
    PrintStream logOut = null;
    try {
        final File logFile = new File("log/CommitTestNonTransactional.txt");
        // PrintStream does not create missing directories
        logFile.getParentFile().mkdirs();
        logOut = new PrintStream(logFile);
        System.setOut(logOut);
    } catch (FileNotFoundException e) {
        // keep logging to the console, but say why the file is missing
        System.err.println("Unable to open log file, logging to the console instead: " + e.getMessage());
    }
    try {
        // set to run for 1 min with the transaction flag set to false
        executeTest(this.threadCount, false, this.maxSleepTime, this.maxOpCount, this.initialCacheSize, 1);
    } finally {
        // restore the console before closing the file so late output is not lost
        System.setOut(originalOut);
        if (logOut != null) {
            logOut.close();
        }
    }
}
/**
 * Records a validation failure and asks every worker to stop.
 * <p>
 * Only the first reported failure is kept, so a follow-up failure from another worker cannot
 * overwrite the original cause. The method is synchronized because several workers may report
 * a failure at the same time.
 *
 * @param message description of the failure; a placeholder is stored when it is null
 */
public synchronized void setFailureMessage(final String message) {
    if (this.isValidData) {
        // store the message before flipping the flag so a reader that sees the flag sees the text
        this.failureMessage = (message != null) ? message : "Validation failed without a message.";
        this.isValidData = false;
    }
    //exception reproduced - kill all threads
    final TestExecutor[] executors = this.threads;
    if (executors == null) {
        return;
    }
    for (TestExecutor thread : executors) {
        // slots are still null while executeTest() is creating the workers
        if (thread != null) {
            thread.shutdown();
        }
    }
}
/**
 * Returns the failure message stored by {@link #setFailureMessage(String)}.
 *
 * @return the stored message; the field is initialised to an empty string
 */
public String getFailureMessage() {
return this.failureMessage;
}
/**
 * Starts the workers, optionally lets them run for a fixed time, waits for all of them to
 * finish and then fails the test if a worker reported a validation failure.
 *
 * @param threadCount number of worker threads to run
 * @param runInTransaction whether each batch of operations is wrapped in a transaction
 * @param maxSleepTime upper bound, in milliseconds, of the random delay before a worker starts
 * @param maxOpCount value the workers use to size each random batch of operations
 * @param initialCacheSize number of records each worker inserts before it is started
 * @param runtimeInMin minutes to run before stopping the workers; with 0 or less the workers
 *        run until one of them reports a failure
 */
private void executeTest(final int threadCount, final boolean runInTransaction, final int maxSleepTime, final int maxOpCount, final int initialCacheSize, final int runtimeInMin) {
    final CountDownLatch endLatch = new CountDownLatch(threadCount);
    boolean interrupted = false;
    this.threads = new TestExecutor[threadCount];
    for (int i = 0; i < threadCount; i++) {
        this.threads[i] = new TestExecutor(this.dbPool.acquire(), runInTransaction, i, endLatch, maxSleepTime, maxOpCount);
        System.out.println("Starting thread id: " + i);
        this.threads[i].seedData(initialCacheSize);
        new Thread(this.threads[i]).start();
    }
    // A worker that failed while later workers were still being created could not stop them
    // (their slots were still null), which would leave endLatch.await() waiting forever.
    if (!this.isValidData) {
        for (TestExecutor thread : this.threads) {
            thread.shutdown();
        }
    }
    final boolean timedRun = runtimeInMin > 0;
    int successfulThreadCount = 0;
    if (timedRun) {
        try {
            // long arithmetic: 60000 * runtimeInMin overflows an int for large runtimes
            Thread.sleep(60000L * runtimeInMin);
        } catch (InterruptedException e) {
            // stop early, but remember the interrupt so it can be restored below
            interrupted = true;
        }
        for (TestExecutor thread : this.threads) {
            if (!thread.isShutdown()) {
                ++successfulThreadCount;
                thread.shutdown();
            }
        }
    }
    try {
        // wait for the workers before asserting, so none is still running when the test ends
        endLatch.await();
    } catch (InterruptedException e) {
        interrupted = true;
    } finally {
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }
    // check if failure has occurred and fail test with message; this is reported first
    // because it names the cause, whereas the thread count only shows the symptom
    if (!this.isValidData) {
        fail(getFailureMessage());
    }
    if (timedRun) {
        //verify number of alive threads against the number of threads specified to run
        assertEquals(threadCount, successfulThreadCount);
    }
}
class TestExecutor implements Runnable {
// Upper bound in milliseconds of the random delay before the worker starts.
private final int maxSleepTime;
private final CountDownLatch endLatch;
// Set by other threads (the test thread or a failing worker) and polled by this worker's loop,
// so it must be volatile for the loop to be guaranteed to see the change.
private volatile boolean shutdown;
private final int maxOpCount;
private final boolean runInTransaction;
private final OGraphDatabase db;
// Records this worker expects to find in the database.
private final List<IdPair> cache;
private final int threadId;
/**
 * Creates a worker bound to the given database connection.
 *
 * @param db connection this worker uses for all of its operations
 * @param runInTransaction whether each batch of operations is wrapped in a transaction
 * @param threadId value stored on every record this worker creates
 * @param endLatch latch counted down when the worker finishes
 * @param maxSleepTime upper bound in milliseconds of the random start delay
 * @param maxOpCount value used to size each random batch of operations
 */
public TestExecutor(final OGraphDatabase db, final boolean runInTransaction, final int threadId, final CountDownLatch endLatch, final int maxSleepTime, final int maxOpCount) {
    // identity and collaborators
    this.threadId = threadId;
    this.db = db;
    this.endLatch = endLatch;
    // behaviour switches
    this.runInTransaction = runInTransaction;
    this.maxOpCount = maxOpCount;
    this.maxSleepTime = maxSleepTime;
    // mutable state
    this.cache = new ArrayList<IdPair>();
    this.shutdown = false;
}
/**
 * Inserts the given number of records and remembers each one in the cache.
 *
 * @param initialCacheSize number of records to insert
 */
public void seedData(final int initialCacheSize) {
    int remaining = initialCacheSize;
    while (remaining-- > 0) {
        final IdPair inserted = insertNewNode();
        this.cache.add(new IdPair(inserted.getOrid(), inserted.getCustomId()));
    }
}
/**
 * Waits a random start delay, then runs batches of operations until shutdown is requested.
 * The end latch is counted down when the loop exits, whether normally or by exception.
 */
public void run() {
    final long startDelay = (long) (Math.random() * this.maxSleepTime);
    try {
        Thread.sleep(startDelay);
    } catch (InterruptedException ignored) {
        //swallow - irrelevant
    }
    try {
        while (!isShutdown()) {
            commitOperations();
        }
    } finally {
        this.endLatch.countDown();
    }
}
/**
 * Performs a random batch of insert and delete operations, inside a transaction when
 * {@code runInTransaction} is set, then validates the database against the cache.
 * <p>
 * A failing operation rolls the transaction back (transactional mode) or is logged
 * (non-transactional mode). A failing validation records the failure message and shuts this
 * worker down.
 */
private void commitOperations() {
    try {
        final List<TempCacheObject> tempCache = new ArrayList<TempCacheObject>();
        try {
            //generate random operation list
            final List<Operation> operations = generateOperations(this.maxOpCount);
            System.out.println("ThreadId: " + this.threadId + " Operations to execute are: " + operations);
            if (this.runInTransaction) {
                this.db.begin();
                System.out.println("ThreadId: " + this.threadId + " Beginning transaction.");
            }
            for (Operation operation : operations) {
                if (Operation.INSERT.equals(operation)) {
                    //perform insert operation
                    final IdPair insertedNode = insertNewNode();
                    final ORID insertId = insertedNode.getOrid();
                    if (!this.runInTransaction) {
                        System.out.println("ThreadId: " + this.threadId + " Inserting " + insertId);
                    }
                    //add inserted id to temp cache
                    tempCache.add(new TempCacheObject(operation, insertId, insertedNode.getCustomId()));
                } else if (Operation.DELETE.equals(operation)) {
                    //get delete id
                    final ORID deleteId = getRandomIdForThread();
                    if (deleteId != null) {
                        if (!this.runInTransaction) {
                            System.out.println("ThreadId: " + this.threadId + " Deleting " + deleteId);
                        }
                        //perform delete operation
                        final Integer customId = deleteExistingNode(deleteId);
                        //add deleted id to temp cache
                        tempCache.add(new TempCacheObject(operation, deleteId, customId));
                    } else {
                        // threadId, not getName(): getName() resolves to the enclosing TestCase
                        // and printed the test name in place of the thread id
                        System.out.println("ThreadId: " + this.threadId + " no ids in database for thread to delete.");
                    }
                }
            }
            if (this.runInTransaction) {
                System.out.println("ThreadId: " + this.threadId + " Committing transaction. " + tempCache);
                this.db.commit();
                System.out.println("ThreadId: " + this.threadId + " transaction committed. " + tempCache);
            }
        } catch (Exception e) {
            if (this.runInTransaction) {
                this.db.rollback();
                // nothing from this batch was applied, so nothing may reach the permanent cache
                tempCache.clear();
                System.out.println("ThreadId: " + this.threadId + " Rolling back transaction due to " + e.getClass().getSimpleName() + " " + e.getMessage());
                e.printStackTrace(System.out);
            } else {
                // Operations that completed before the failure stay in tempCache because they were
                // already applied. The failure itself used to be dropped without any output.
                System.out.println("ThreadId: " + this.threadId + " Operation failed due to " + e.getClass().getSimpleName() + " " + e.getMessage());
                e.printStackTrace(System.out);
            }
        }
        // update permanent cache from temp cache
        updateCache(tempCache);
        //validate db against permanent cache
        try {
            validateCustomIdsAgainstDatabase();
        } catch (Exception e) {
            // this check only logs; validateDatabase() below decides whether the test fails
            System.out.println(e.getMessage());
        }
        validateDatabase(this.cache);
    } catch (Exception e) {
        System.out.println("ThreadId: " + this.threadId + " threw a validation exception: " + e.getMessage());
        e.printStackTrace(System.out);
        //validation failed - set failure message
        setFailureMessage(e.getMessage());
        this.shutdown = true;
    }
}
/**
 * Verifies that the custom id of every cached record is stored on some record of the test class.
 * <p>
 * The stored ids are collected once into a set, so the check is linear rather than scanning
 * every database record for every cache entry, and a record without an id value no longer
 * causes a NullPointerException.
 *
 * @throws Exception if a cached custom id is not present in the database
 */
private void validateCustomIdsAgainstDatabase() throws Exception {
    final Set<Object> customIdsInDb = new HashSet<Object>();
    final ORecordIteratorClass<ODocument> iterator = this.db.browseClass(TEST_CLASS);
    while (iterator.hasNext()) {
        final Object storedId = iterator.next().field(ID);
        customIdsInDb.add(storedId);
    }
    for (IdPair cacheInstance : this.cache) {
        final Integer customId = cacheInstance.getCustomId();
        if (!customIdsInDb.contains(customId)) {
            throw new Exception("Custom id: " + customId + " exists in cache but was not found in db.");
        }
    }
}
/**
 * Returns the current value of this worker's shutdown flag.
 */
public boolean isShutdown() {
return this.shutdown;
}
/**
 * Cross-checks the given cache against the database in both directions: every cached record
 * must exist in the database, and every database record carrying this worker's thread id must
 * be present in the cache.
 *
 * @param cache the records this worker expects to exist
 * @throws Exception on the first mismatch found
 */
private void validateDatabase(final List<IdPair> cache) throws Exception {
    // cache -> database
    for (final IdPair cached : cache) {
        final ORID expected = cached.getOrid();
        if (isInDatabase(expected)) {
            continue;
        }
        throw new Exception("Insert issue: expected record " + expected + " was not found in database.");
    }
    // database -> cache, restricted to records owned by this worker
    final Integer ownThreadId = Integer.valueOf(this.threadId);
    for (final ODocument stored : this.db.browseClass(TEST_CLASS)) {
        if (!ownThreadId.equals(stored.field(THREAD_ID))) {
            continue;
        }
        final ORID storedRid = stored.getIdentity();
        final Integer storedCustomId = stored.field(ID);
        final boolean known = cache.contains(new IdPair(storedRid, storedCustomId));
        if (!known) {
            throw new Exception("Delete issue: record id " + storedRid + " for thread id " + this.threadId + " was not found in cache.");
        }
    }
}
/**
 * Tells whether the record with the given id can be loaded and carries this worker's thread id.
 *
 * @param id record id to look up
 * @return true only when the record exists and its thread id matches this worker
 */
private boolean isInDatabase(final ORID id) throws Exception {
    final ODocument found = this.db.getRecord(id);
    if (found == null) {
        return false;
    }
    return Integer.valueOf(this.threadId).equals(found.field(THREAD_ID));
}
/**
 * Applies the operations recorded in the temp cache to the permanent cache: inserted records
 * are added and deleted records are removed.
 *
 * @param tempCache operations performed in the current batch
 */
private void updateCache(final List<TempCacheObject> tempCache) {
    for (final TempCacheObject entry : tempCache) {
        final IdPair pair = new IdPair(entry.getOrientId(), entry.getCustomId());
        final Operation kind = entry.getOperation();
        if (Operation.INSERT.equals(kind)) {
            this.cache.add(pair);
            continue;
        }
        if (Operation.DELETE.equals(kind)) {
            this.cache.remove(pair);
        }
    }
}
/**
 * Creates a vertex tagged with this worker's thread id and a freshly generated custom id.
 * When this worker already owns a record, an edge from the new vertex to a randomly chosen
 * one is saved as well.
 *
 * @return the record id and custom id of the saved vertex
 */
private IdPair insertNewNode() {
    final ODocument vertex = this.db.createVertex(TEST_CLASS);
    vertex.field(THREAD_ID, Integer.valueOf(this.threadId));
    final Integer customId = Integer.valueOf(OrientCommitTest.this.idGenerator.getAndIncrement());
    vertex.field(ID, customId);
    final ORID neighbourId = getRandomIdForThread();
    if (neighbourId != null) {
        final ODocument neighbour = this.db.getRecord(neighbourId);
        // the edge is saved before the vertex itself
        this.db.save(this.db.createEdge(vertex, neighbour));
    }
    return new IdPair(this.db.save(vertex).getIdentity(), customId);
}
/**
 * Removes every incoming and outgoing edge of the given vertex, then removes the vertex.
 *
 * @param recordId id of the vertex to delete
 * @return the custom id that was stored on the deleted vertex
 */
private Integer deleteExistingNode(final ORID recordId) {
    final Set<OIdentifiable> incoming = this.db.getInEdges(recordId);
    final Set<OIdentifiable> outgoing = this.db.getOutEdges(recordId);
    // merge both directions so an edge present in both sets is removed only once
    final Set<OIdentifiable> connected = new HashSet<OIdentifiable>();
    connected.addAll(incoming);
    connected.addAll(outgoing);
    for (final OIdentifiable connection : connected) {
        this.db.removeEdge(connection);
    }
    final ORID vertexId = new ORecordId(recordId);
    final ODocument vertex = this.db.getRecord(vertexId);
    // read the custom id before the vertex is removed
    final Integer storedCustomId = vertex.field(ID);
    this.db.removeVertex(new ORecordId(vertexId));
    return storedCustomId;
}
/**
 * Picks, at random, the id of one database record of the test class that carries this worker's
 * thread id.
 *
 * @return a randomly chosen record id, or null when this worker owns no records
 */
private ORID getRandomIdForThread() {
    final Integer ownThreadId = Integer.valueOf(this.threadId);
    final List<ORID> candidates = new ArrayList<ORID>();
    for (final ODocument stored : this.db.browseClass(TEST_CLASS)) {
        if (ownThreadId.equals(stored.field(THREAD_ID))) {
            candidates.add(stored.getIdentity());
        }
    }
    if (candidates.isEmpty()) {
        return null;
    }
    return candidates.get(random.nextInt(candidates.size()));
}
/**
 * Builds a list of randomly chosen operations.
 *
 * @param maxOpCount value from which the random length of the list is derived
 * @return the generated operations
 */
private List<Operation> generateOperations(final int maxOpCount) {
    final int opCount = (int) (Math.random() * maxOpCount / 2 + maxOpCount / 2);
    final List<Operation> operationsList = new ArrayList<Operation>();
    while (operationsList.size() < opCount) {
        operationsList.add(Operation.getRandom());
    }
    return operationsList;
}
/**
 * Sets the shutdown flag that the loop in run() checks before each batch of operations.
 */
private void shutdown() {
this.shutdown = true;
}
/**
 * One operation performed in the current batch, together with the record id and custom id of
 * the record it touched.
 */
private class TempCacheObject {
    private final Operation operation;
    private final ORID orientId;
    private final Integer customId;

    public TempCacheObject(final Operation operation, final ORID orientId, final Integer customId) {
        this.customId = customId;
        this.orientId = orientId;
        this.operation = operation;
    }

    public Operation getOperation() {
        return this.operation;
    }

    public ORID getOrientId() {
        return this.orientId;
    }

    public Integer getCustomId() {
        return this.customId;
    }

    public String toString() {
        return "Operation:" + this.operation + ", ORID:" + this.orientId + ", CustomId:" + this.customId;
    }
}
}
/**
 * The two kinds of operation a worker can perform.
 */
private static enum Operation {
    INSERT, DELETE;

    /**
     * Picks an operation at random; INSERT is chosen when the random draw is below 0.55.
     */
    public static Operation getRandom() {
        return Math.random() < 0.55 ? INSERT : DELETE;
    }
}
/**
 * Pairs a record id with the custom id stored on that record.
 * <p>
 * Two pairs are equal when both ids are equal. {@code hashCode} is overridden together with
 * {@code equals}, and both tolerate null ids.
 */
private static class IdPair {
    private final ORID orid;
    private final Integer customId;

    public IdPair(final ORID orid, final Integer customId) {
        super();
        this.orid = orid;
        this.customId = customId;
    }

    public ORID getOrid() {
        return this.orid;
    }

    public Integer getCustomId() {
        return this.customId;
    }

    @Override
    public boolean equals(final Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof IdPair)) {
            return false;
        }
        final IdPair other = (IdPair) obj;
        return nullSafeEquals(other.orid, this.orid) && nullSafeEquals(other.customId, this.customId);
    }

    /**
     * Consistent with {@link #equals(Object)}: built from the same two ids.
     */
    @Override
    public int hashCode() {
        final int oridHash = (this.orid == null) ? 0 : this.orid.hashCode();
        final int customIdHash = (this.customId == null) ? 0 : this.customId.hashCode();
        return 31 * oridHash + customIdHash;
    }

    // Equality that treats two nulls as equal instead of throwing.
    private static boolean nullSafeEquals(final Object left, final Object right) {
        return (left == null) ? (right == null) : left.equals(right);
    }
}
/**
 * Creates a fresh database at {@link #DB_URL}, dropping an existing one first, and returns a
 * connection pool for it.
 * <p>
 * For a remote URL the database is managed through {@code OServerAdmin}; otherwise it is
 * created directly. The level-1 cache is disabled globally before anything is opened.
 *
 * @return a pool sized for the worker threads plus one extra connection
 * @throws RuntimeException wrapping the IOException if the remote server calls fail
 */
public OGraphDatabasePool bootstrap() {
    OGlobalConfiguration.CACHE_LEVEL1_ENABLED.setValue(Boolean.FALSE);
    if (DB_URL.startsWith("remote")) {
        try {
            final OServerAdmin server = new OServerAdmin(DB_URL).connect(DB_USER, DB_PASSWORD);
            try {
                if (server.existsDatabase(DB_STORAGE)) {
                    server.dropDatabase(DB_STORAGE);
                }
                // DB_TYPE holds the same "graph" value that used to be repeated here
                server.createDatabase(DB_TYPE, DB_STORAGE);
            } finally {
                // close the admin connection on the failure path as well
                server.close();
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    } else {
        final OGraphDatabase db = new OGraphDatabase(DB_URL);
        if (db.exists()) {
            db.open(DB_USER, DB_PASSWORD);
            db.drop();
        }
        db.create();
        db.close();
    }
    final OGraphDatabasePool pool = new OGraphDatabasePool(DB_URL, DB_USER, DB_PASSWORD);
    pool.setup(this.threadCount + 1, 2 * this.threadCount);
    return pool;
}
/**
 * Creates the vertex type used by the test with two mandatory, non-null integer properties:
 * the owning thread id and the custom id.
 *
 * @param db connection on which the schema is created
 */
public void buildSchemaAndSeed(final OGraphDatabase db) {
    final OClass vertexType = db.createVertexType(TEST_CLASS);
    for (final String propertyName : new String[] { THREAD_ID, ID }) {
        vertexType.createProperty(propertyName, OType.INTEGER).setMandatory(true).setNotNull(true);
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment