Skip to content

Instantly share code, notes, and snippets.

Created December 5, 2014 09:21
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save anonymous/4c83f2962b57fce4c3df to your computer and use it in GitHub Desktop.
Save anonymous/4c83f2962b57fce4c3df to your computer and use it in GitHub Desktop.
Cassandra CAS Test
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Cassandra;
using Xunit;
namespace ConcurrentKvsPoc
{
public class ConcurrentUpdatesTest : CassandraTest
{
    // Only used by GetRandomUpdateOrder, which runs to completion before any update task starts,
    // so the non-thread-safe Random is never touched concurrently.
    private static readonly Random Random = new Random();

    // Statistics incremented concurrently by the Update tasks (always via Interlocked).
    private int abortedUpdates;
    private int partialWriteTimeouts;
    private int totalWriteTimeouts;

    /// <summary>
    /// Drops and recreates the <c>objects</c> table used by <see cref="SupportsOptimisticConcurrency"/>.
    /// </summary>
    /// <remarks>
    /// NOTE(review): xUnit does not order facts, so this may run before or after the concurrency
    /// test; presumably it is meant to be run on its own to reset the schema.
    /// </remarks>
    [Fact]
    public async Task ResetTable()
    {
        await Session.ExecuteAsync(GetQuorumStatement("DROP TABLE IF EXISTS objects;"));
        await Session.ExecuteAsync(GetQuorumStatement("CREATE TABLE objects (key text, version int, PRIMARY KEY(key));"));
    }

    /// <summary>
    /// Inserts <c>itemCount</c> rows at version 0, then concurrently increments each row
    /// <c>updatesPerItem</c> times with read + compare-and-set. Finally it asserts that the sum of
    /// all versions equals the number of requested updates (no lost or duplicated writes).
    /// </summary>
    [Fact]
    public async Task SupportsOptimisticConcurrency()
    {
        // Unique key prefix so repeated runs do not interfere with each other's rows.
        string test = "test-" + Guid.NewGuid().ToString().Substring(4);
        const int itemCount = 100;
        const int updatesPerItem = 10;
        const int expectedUpdateCount = itemCount * updatesPerItem;

        Trace.WriteLine(string.Format("Running test with {0} items and {1} updates per item.", itemCount, updatesPerItem));
        Trace.WriteLine("");

        // Start all three prepares at once; only the insert statement is needed immediately.
        Task<PreparedStatement> prepareInsertStatement = Session.PrepareAsync("INSERT INTO objects (key, version) VALUES (?, 0) IF NOT EXISTS;");
        Task<PreparedStatement> prepareSelectStatement = Session.PrepareAsync("SELECT version FROM objects WHERE key = ?;");
        Task<PreparedStatement> prepareUpdateStatement = Session.PrepareAsync("UPDATE objects SET version = ? WHERE key = ? IF version = ?;");

        await InsertTestItems(itemCount, test, await prepareInsertStatement);
        PreparedStatement selectStatement = await prepareSelectStatement;
        PreparedStatement updateStatement = await prepareUpdateStatement;

        await UpdateItems(itemCount, updatesPerItem, test, selectStatement, updateStatement);

        Trace.WriteLine("Number of updates: " + expectedUpdateCount);
        Trace.WriteLine("Number of aborted updates due to concurrency: " + abortedUpdates);
        Trace.WriteLine("Number of total write timeouts: " + totalWriteTimeouts);
        Trace.WriteLine("Number of partial write timeouts: " + partialWriteTimeouts);

        await CheckResults(itemCount, updatesPerItem, test, selectStatement);
    }

    /// <summary>
    /// Increments the version of <paramref name="key"/> by exactly one attempt that succeeds.
    /// It reads the current version, then tries a compare-and-set, and repeats until the CAS
    /// reports it was applied or a partial write timeout is observed.
    /// </summary>
    private async Task Update(string key, PreparedStatement selectStatement, PreparedStatement updateStatement)
    {
        bool done = false;
        while (!done)
        {
            // Read the current version. GetItem returns null when the row is not visible (yet)
            // at the read consistency used here, so keep polling until it shows up.
            TestItem item = null;
            while (item == null)
                item = await GetItem(key, selectStatement);

            try
            {
                done = await CompareAndSet(key, item.Version, updateStatement);

                // The IF condition did not match: another writer bumped the version first.
                if (!done)
                    Interlocked.Increment(ref abortedUpdates);
            }
            catch (WriteTimeoutException wte)
            {
                // NOTE(review): for lightweight transactions a write timeout leaves the outcome
                // uncertain. Treating a partial timeout as applied, and retrying after a total
                // timeout, can respectively lose or duplicate an increment.
                if (wte.ReceivedAcknowledgements > 0)
                {
                    Interlocked.Increment(ref partialWriteTimeouts);
                    done = true;
                }
                else
                {
                    Interlocked.Increment(ref totalWriteTimeouts);
                }
            }
        }
    }

    /// <summary>
    /// Sets the version of <paramref name="key"/> to <paramref name="currentCount"/> + 1, but only
    /// if it is still <paramref name="currentCount"/>. This is a lightweight transaction (SERIAL
    /// Paxos phase, QUORUM commit).
    /// </summary>
    /// <returns>The <c>[applied]</c> flag returned by Cassandra.</returns>
    /// <exception cref="InvalidOperationException">The CAS result contained no row.</exception>
    private async Task<bool> CompareAndSet(string key, int currentCount, PreparedStatement updateStatement)
    {
        IStatement statement = updateStatement
            .Bind(currentCount + 1, key, currentCount)
            .SetSerialConsistencyLevel(ConsistencyLevel.Serial)
            .SetConsistencyLevel(ConsistencyLevel.Quorum);
        RowSet result = await Session.ExecuteAsync(statement);
        Row row = result.GetRows().SingleOrDefault();
        if (row == null)
            throw new InvalidOperationException("No row in update result.");
        return row.GetValue<bool>("[applied]");
    }

    /// <summary>
    /// Runs all <c>itemCount * updatesPerItem</c> updates concurrently, in a random interleaving,
    /// and waits for every one of them.
    /// </summary>
    private async Task UpdateItems(int itemCount, int updatesPerItem, string test, PreparedStatement selectStatement, PreparedStatement updateStatement)
    {
        IEnumerable<int> updateOrder = GetRandomUpdateOrder(itemCount, updatesPerItem);
        Task[] updateTasks = updateOrder
            .Select(item => Update(GetKey(item, test), selectStatement, updateStatement))
            .ToArray();
        await Task.WhenAll(updateTasks);
    }

    /// <summary>
    /// Reads every test item back and asserts that the versions sum to
    /// <c>itemCount * updatesPerItem</c>.
    /// </summary>
    private async Task CheckResults(int itemCount, int updatesPerItem, string test, PreparedStatement selectStatement)
    {
        // Read at SERIAL rather than QUORUM. A SERIAL read completes any lightweight transaction
        // left in progress (e.g. by a write timeout) before returning. A QUORUM read can miss such
        // a half-committed update and report it as lost.
        Task<TestItem>[] reads = Range(itemCount)
            .Select(i => GetItem(i, test, selectStatement, ConsistencyLevel.Serial))
            .ToArray();
        TestItem[] items = await Task.WhenAll(reads);

        // A missing row would otherwise surface as a NullReferenceException in Sum below.
        int missingItemCount = items.Count(item => item == null);
        if (missingItemCount != 0)
            Trace.WriteLine(string.Format("MISSING ITEMS: {0}", missingItemCount));
        Assert.Equal(0, missingItemCount);

        int writeCount = items.Sum(item => item.Version);
        int expectedWriteCount = itemCount * updatesPerItem;
        int lostWriteCount = expectedWriteCount - writeCount;

        // A negative difference means some increments were applied twice (e.g. a CAS that
        // timed out, was retried, but had actually succeeded). Report that distinctly.
        if (lostWriteCount > 0)
        {
            Trace.WriteLine("");
            Trace.WriteLine(string.Format("LOST WRITES: {0} (or {1:P})", lostWriteCount, lostWriteCount / (double)expectedWriteCount));
        }
        else if (lostWriteCount < 0)
        {
            Trace.WriteLine("");
            Trace.WriteLine(string.Format("EXTRA WRITES: {0} (or {1:P})", -lostWriteCount, -lostWriteCount / (double)expectedWriteCount));
        }

        DisplayResults(items);
        Assert.Equal(0, lostWriteCount);
    }

    /// <summary>
    /// Traces a histogram of final versions: for each version (number of applied updates),
    /// the number of items that ended at it, highest version first.
    /// </summary>
    private static void DisplayResults(IEnumerable<TestItem> items)
    {
        var results = items
            .GroupBy(item => item.Version)
            .OrderByDescending(group => group.Key)
            .Select(group => new { UpdateCount = group.Key, ItemCount = group.Count() })
            .ToArray();

        Trace.WriteLine("");
        Trace.WriteLine("Results: ");
        Trace.WriteLine("");

        const string updatesHeader = "Updates";
        const string itemCountHeader = "Item count";

        // Right-align each value to the width of its column header.
        string formatString = string.Format("{{0,{0}}} | {{1,{1}}}", updatesHeader.Length, itemCountHeader.Length);
        Trace.WriteLine(string.Format("{0} | {1}", updatesHeader, itemCountHeader));
        foreach (var result in results)
            Trace.WriteLine(string.Format(formatString, result.UpdateCount, result.ItemCount));
    }

    /// <summary>Builds the row key for item <paramref name="i"/> of run <paramref name="test"/>.</summary>
    private static string GetKey(int i, string test)
    {
        return string.Format("{0}\\{1}", test, i);
    }

    /// <summary>Reads item <paramref name="i"/> of run <paramref name="test"/>; null if the row is absent.</summary>
    private async Task<TestItem> GetItem(int i, string test, PreparedStatement selectStatement, ConsistencyLevel consistencyLevel = ConsistencyLevel.One)
    {
        string key = GetKey(i, test);
        return await GetItem(key, selectStatement, consistencyLevel);
    }

    /// <summary>
    /// Reads the version stored under <paramref name="key"/> at <paramref name="consistencyLevel"/>.
    /// </summary>
    /// <returns>The item, or null when no row exists for the key.</returns>
    private async Task<TestItem> GetItem(string key, PreparedStatement selectStatement, ConsistencyLevel consistencyLevel = ConsistencyLevel.One)
    {
        IStatement boundStatement = selectStatement
            .Bind(key)
            .SetConsistencyLevel(consistencyLevel);
        RowSet result = await Session.ExecuteAsync(boundStatement);
        Row row = result.GetRows().SingleOrDefault();
        if (row == null)
            return null;
        int version = row.GetValue<int>("version");
        return new TestItem(version);
    }

    /// <summary>
    /// Returns every item index in <c>[0, itemCount)</c> repeated <paramref name="updatesPerItem"/>
    /// times, in uniformly random order.
    /// </summary>
    private static IEnumerable<int> GetRandomUpdateOrder(int itemCount, int updatesPerItem)
    {
        int[] order = Range(itemCount)
            .SelectMany(i => Enumerable.Repeat(i, updatesPerItem))
            .ToArray();

        // In-place Fisher-Yates shuffle: O(n), unlike repeatedly removing random list entries.
        for (int last = order.Length - 1; last > 0; last--)
        {
            int pick = Random.Next(last + 1);
            int swap = order[last];
            order[last] = order[pick];
            order[pick] = swap;
        }
        return order;
    }

    /// <summary>Inserts items <c>0..itemCount-1</c> at version 0, all concurrently.</summary>
    private async Task InsertTestItems(int itemCount, string test, PreparedStatement insertStatement)
    {
        Task[] inserts = Range(itemCount)
            .Select(i => InsertItem(test, i, insertStatement))
            .ToArray();
        await Task.WhenAll(inserts);
    }

    /// <summary>Executes the (IF NOT EXISTS) insert for a single item key.</summary>
    private Task InsertItem(string test, int i, PreparedStatement insertStatement)
    {
        BoundStatement boundStatement = insertStatement.Bind(GetKey(i, test));
        return Session.ExecuteAsync(boundStatement);
    }

    /// <summary>Shorthand for <c>Enumerable.Range(0, itemCount)</c>.</summary>
    private static IEnumerable<int> Range(int itemCount)
    {
        return Enumerable.Range(0, itemCount);
    }

    /// <summary>Wraps raw CQL in a statement executed at QUORUM.</summary>
    private static IStatement GetQuorumStatement(string cql)
    {
        return new SimpleStatement(cql).SetConsistencyLevel(ConsistencyLevel.Quorum);
    }

    /// <summary>Immutable snapshot of a row's version as read from Cassandra.</summary>
    private class TestItem
    {
        private readonly int version;

        public TestItem(int version)
        {
            this.version = version;
        }

        public int Version
        {
            get { return version; }
        }
    }
}
public abstract class CassandraTest : IDisposable
{
    // Connection settings used by the parameterless constructor.
    private static readonly string[] DefaultContactPoints = { "10.43.192.40", "10.43.192.41", "10.43.192.42" };
    private const string DefaultKeyspace = "kvs_poc";

    /// <summary>Session connected to the test keyspace, available to derived tests.</summary>
    protected readonly ISession Session;

    private readonly Cluster cluster;

    /// <summary>Connects to the default test cluster and the <c>kvs_poc</c> keyspace.</summary>
    protected CassandraTest()
        : this(DefaultKeyspace, DefaultContactPoints)
    {
    }

    /// <summary>
    /// Builds a cluster from <paramref name="contactPoints"/> and opens a session on
    /// <paramref name="keyspace"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="keyspace"/> is null.</exception>
    /// <exception cref="ArgumentException">No contact points were supplied.</exception>
    protected CassandraTest(string keyspace, params string[] contactPoints)
    {
        if (keyspace == null)
            throw new ArgumentNullException("keyspace");
        if (contactPoints == null || contactPoints.Length == 0)
            throw new ArgumentException("At least one contact point is required.", "contactPoints");

        var builder = Cluster.Builder();
        foreach (string contactPoint in contactPoints)
            builder = builder.AddContactPoint(contactPoint);
        cluster = builder.Build();

        try
        {
            Session = cluster.Connect(keyspace);
        }
        catch
        {
            // A throwing constructor means Dispose will never be called on this instance,
            // so release the already-built cluster before propagating the failure.
            cluster.Dispose();
            throw;
        }
    }

    /// <summary>Disposes the session, then the cluster that owns it.</summary>
    void IDisposable.Dispose()
    {
        if (Session != null)
            Session.Dispose();
        if (cluster != null)
            cluster.Dispose();
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment