Skip to content

Instantly share code, notes, and snippets.

@salessandri
Created November 2, 2016 05:33
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save salessandri/41eebffd917637cce6231d3eb5c0b374 to your computer and use it in GitHub Desktop.
Save salessandri/41eebffd917637cce6231d3eb5c0b374 to your computer and use it in GitHub Desktop.
Proof of concept of how to design the architecture to handle transactions in C++
#include <atomic>
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <thread>
#include <boost/optional.hpp>
#include <sqlite3.h>
// Interfaces
// Smart-pointer handle to a SQLite connection. The deleter is type-erased
// through std::function, so whoever creates the handle decides what happens
// when it goes out of scope (this file uses both a closing deleter and a
// do-nothing deleter) while every handle still has the same type.
using ScopedDbConnection = std::unique_ptr<sqlite3, std::function<void(sqlite3*)>>;
/// Abstract source of SQLite connection handles.
class DbConnectionManager {
public:
    /// Virtual so implementations can be destroyed safely through a pointer
    /// or reference to this interface.
    virtual ~DbConnectionManager() = default;

    /// Returns a connection handle; what the handle does when it is released
    /// (close the connection or leave it open) is up to the implementation.
    virtual ScopedDbConnection getConnection() = 0;
};
/// Thrown from inside a transaction callback to request a rollback.
/// Carries no data; only its type matters.
struct AbortTransaction {};
/// Signals to the caller that a transaction could not be committed and was
/// aborted instead. Carries no data; only its type matters.
struct TransactionAborted {};
/// Abstract interface for running a unit of work inside a transaction.
class TransactionManager {
public:
    /// Virtual so implementations can be destroyed safely through a pointer
    /// or reference to this interface.
    virtual ~TransactionManager() = default;

    /// Runs `f` inside a transaction. How commit, rollback and nesting are
    /// handled is defined by the implementation.
    virtual void performInTransaction(const std::function<void()>& f) = 0;
};
struct Entity;
/// Error type for entity repository operations that fail.
/// Carries no data; only its type matters.
struct EntityRepoException {};
/// Abstract persistence interface for Entity objects.
class EntityRepo {
public:
    /// Virtual so implementations can be destroyed safely through a pointer
    /// or reference to this interface.
    virtual ~EntityRepo() = default;

    /// Persists `entity`.
    virtual void save(const Entity& entity) = 0;

    /// Looks up the entity with the given id; an empty optional means that
    /// no such entity was found.
    virtual boost::optional<Entity> findById(uint64_t id) = 0;
};
// Implementations
/// Connection manager that hands out connections to one fixed database file.
class ConcreteConnectionManager : public DbConnectionManager {
public:
    /// `explicit` so that a std::string is never silently converted into a
    /// connection manager.
    /// @param dbFile path of the database file used by getConnection().
    explicit ConcreteConnectionManager(const std::string& dbFile);

    /// Returns a handle to a connection on the configured database file.
    ScopedDbConnection getConnection() override;

private:
    std::string _dbFile;  // path handed to SQLite when a connection is opened
};
ConcreteConnectionManager::ConcreteConnectionManager(const std::string& dbFile) :
_dbFile(dbFile)
{}
/// Opens a new connection to the configured database file in read/write
/// mode, creating the file if it does not exist yet.
///
/// @return a handle that closes the connection when it is destroyed.
/// @throws std::runtime_error if SQLite cannot open the database; the
///         failure is also logged to std::cerr.
ScopedDbConnection ConcreteConnectionManager::getConnection()
{
    sqlite3* rawConnection = nullptr;
    const int openResult = sqlite3_open_v2(
        _dbFile.c_str(),
        &rawConnection,
        SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX,
        nullptr
    );
    // Take ownership immediately: SQLite normally hands back a handle even
    // when opening fails, and that handle still has to be closed. Wrapping it
    // first means every exit path below releases it.
    ScopedDbConnection connection(rawConnection, [](sqlite3* c) { sqlite3_close_v2(c); });
    if (openResult != SQLITE_OK) {
        std::cerr << "[-] Could not open database: " << openResult
                  << " (" << sqlite3_errstr(openResult) << ")" << std::endl;
        throw std::runtime_error("Database could not be opened");
    }
    return connection;
}
/// Transaction manager that keeps at most one open transaction per thread.
///
/// It is also a DbConnectionManager: getConnection() hands out the calling
/// thread's transaction connection while that thread is inside a
/// transaction, so code using it joins the transaction transparently.
class ConcreteTransactionManager : public DbConnectionManager, public TransactionManager {
public:
    /// `explicit` so that a connection manager is never silently converted
    /// into a transaction manager.
    /// @param connectionManager source of real connections; held by
    ///        reference, so it must outlive this object.
    explicit ConcreteTransactionManager(ConcreteConnectionManager& connectionManager);

    ScopedDbConnection getConnection() override;
    void performInTransaction(const std::function<void()>& f) override;

private:
    /// Book-keeping for one thread's open transaction.
    struct TransactionInfo {
        ScopedDbConnection dbConnection;  // connection the transaction runs on
        uint32_t count;                   // nesting depth of performInTransaction calls
    };

    TransactionInfo& setupTransaction(const std::thread::id& threadId);
    void abortTransaction(TransactionInfo& transactionInfo);
    void commitTransaction(TransactionInfo& transactionInfo);

    ConcreteConnectionManager& _connectionManager;
    std::map<std::thread::id, TransactionInfo> _currentTransactions;  // open transactions by thread
    std::mutex _currentTransactionsMutex;                             // guards _currentTransactions
};
/// Binds the manager to the connection source it will draw connections from.
/// Only the reference is stored; no transaction state exists yet.
ConcreteTransactionManager::ConcreteTransactionManager(ConcreteConnectionManager& connectionManager)
    : _connectionManager(connectionManager)
{
}
/// Returns the connection the calling thread should use.
///
/// If the calling thread has an open transaction, the handle refers to that
/// transaction's connection and does nothing when released. Otherwise a
/// connection is obtained from the underlying connection manager.
ScopedDbConnection ConcreteTransactionManager::getConnection()
{
    const std::thread::id self = std::this_thread::get_id();
    std::lock_guard<std::mutex> guard(_currentTransactionsMutex);

    const auto entry = _currentTransactions.find(self);
    if (entry != _currentTransactions.end()) {
        // Borrowed handle: the transaction entry keeps owning the connection,
        // so the deleter deliberately does nothing.
        sqlite3* borrowed = entry->second.dbConnection.get();
        return ScopedDbConnection(borrowed, [](sqlite3*) {});
    }

    return _connectionManager.getConnection();
}
/// Rolls back the given transaction and removes the calling thread's entry
/// from the table of open transactions.
///
/// The entry is looked up by the calling thread's id, so this must run on
/// the thread that owns `transactionInfo`. `transactionInfo` refers to that
/// map entry and must not be used after this call returns.
void ConcreteTransactionManager::abortTransaction(TransactionInfo& transactionInfo)
{
    sqlite3* const db = transactionInfo.dbConnection.get();

    // SQLite may roll a transaction back on its own after certain errors; a
    // ROLLBACK issued afterwards fails with "no transaction is active" even
    // though nothing is wrong. sqlite3_get_autocommit() returns zero only
    // while a transaction is still open, so only roll back in that case.
    if (sqlite3_get_autocommit(db) == 0) {
        const int abortResult = sqlite3_exec(
            db,
            "ROLLBACK TRANSACTION",
            nullptr,
            nullptr,
            nullptr
        );
        // If an open transaction cannot be aborted, something bad happened.
        assert(abortResult == SQLITE_OK);
        (void)abortResult;  // the assert compiles away under NDEBUG
    }

    const auto threadId = std::this_thread::get_id();
    {
        std::lock_guard<std::mutex> _(_currentTransactionsMutex);
        _currentTransactions.erase(threadId);
    }
}
/// Commits the given transaction and removes the calling thread's entry from
/// the table of open transactions.
///
/// If the COMMIT statement fails, the transaction is aborted instead.
/// @throws TransactionAborted when the commit failed.
void ConcreteTransactionManager::commitTransaction(TransactionInfo& transactionInfo)
{
    sqlite3* const db = transactionInfo.dbConnection.get();
    const bool committed =
        sqlite3_exec(db, "COMMIT TRANSACTION", nullptr, nullptr, nullptr) == SQLITE_OK;

    if (!committed) {
        abortTransaction(transactionInfo);
        throw TransactionAborted();
    }

    const std::thread::id self = std::this_thread::get_id();
    std::lock_guard<std::mutex> guard(_currentTransactionsMutex);
    _currentTransactions.erase(self);
}
/// Enters a transaction for `threadId`.
///
/// If that thread already has an open transaction, its nesting count is
/// incremented and the existing entry is returned. Otherwise a connection is
/// obtained, "BEGIN TRANSACTION" is executed on it and a new entry with
/// count 1 is registered.
///
/// @throws std::runtime_error if the transaction cannot be started.
ConcreteTransactionManager::TransactionInfo&
ConcreteTransactionManager::setupTransaction(const std::thread::id& threadId)
{
    std::lock_guard<std::mutex> guard(_currentTransactionsMutex);

    const auto existing = _currentTransactions.find(threadId);
    if (existing != _currentTransactions.end()) {
        // Nested call: join the transaction that is already running.
        existing->second.count += 1;
        return existing->second;
    }

    TransactionInfo fresh{_connectionManager.getConnection(), 1};
    const int beginResult = sqlite3_exec(
        fresh.dbConnection.get(), "BEGIN TRANSACTION", nullptr, nullptr, nullptr);
    if (beginResult != SQLITE_OK) {
        throw std::runtime_error("Could not start transaction");
    }

    return _currentTransactions.emplace(threadId, std::move(fresh)).first->second;
}
void ConcreteTransactionManager::performInTransaction(const std::function<void()>& f)
{
auto threadId = std::this_thread::get_id();
TransactionInfo& transactionInfo = setupTransaction(threadId);
try {
f();
}
catch (AbortTransaction&) {
if (--transactionInfo.count > 0) {
throw;
}
abortTransaction(transactionInfo);
return;
}
catch (...) {
if (--transactionInfo.count > 0) {
throw;
}
abortTransaction(transactionInfo);
throw;
}
if (--transactionInfo.count > 0) {
return;
}
commitTransaction(transactionInfo);
}
/// Plain record stored in the `entities` table.
struct Entity {
    int64_t id;       ///< value of the `id` column
    int64_t balance;  ///< value of the `balance` column
};
/// EntityRepo backed by the SQLite `entities` table.
class ConcreteEntityRepo : public EntityRepo {
public:
    /// `explicit` so that a connection manager is never silently converted
    /// into a repository.
    /// @param connectionManager source of the connection used by each
    ///        operation; held by reference, so it must outlive this object.
    explicit ConcreteEntityRepo(DbConnectionManager& connectionManager);

    void save(const Entity& entity) override;
    boost::optional<Entity> findById(uint64_t id) override;

private:
    DbConnectionManager& _connectionManager;
};
/// Binds the repository to the connection source it will use for every
/// operation. Only the reference is stored.
ConcreteEntityRepo::ConcreteEntityRepo(DbConnectionManager& connectionManager)
    : _connectionManager(connectionManager)
{
}
void ConcreteEntityRepo::save(const Entity& entity)
{
const char* sql = "INSERT OR REPLACE INTO entities (id, balance) VALUES (?, ?)";
ScopedDbConnection dbConnection = _connectionManager.getConnection();
sqlite3_stmt* rawStmt;
int stmtCreationResult = sqlite3_prepare_v2(
dbConnection.get(),
sql,
-1,
&rawStmt,
nullptr
);
if (stmtCreationResult != SQLITE_OK) {
throw EntityRepoException();
}
auto stmtDeleter = [](sqlite3_stmt* p) { sqlite3_finalize(p); };
std::unique_ptr<sqlite3_stmt, decltype(stmtDeleter)> stmt(rawStmt, stmtDeleter);
if (sqlite3_bind_int64(stmt.get(), 1, entity.id) != SQLITE_OK) {
throw EntityRepoException();
}
if (sqlite3_bind_int64(stmt.get(), 2, entity.balance) != SQLITE_OK) {
throw EntityRepoException();
}
if (sqlite3_step(stmt.get()) != SQLITE_DONE) {
throw EntityRepoException();
}
}
/// Reads the row of the `entities` table whose id equals `id`.
///
/// @param id key to look up; it is bound to the query as a signed 64-bit
///        integer.
/// @return the entity, or boost::none when the query yields no row.
/// @throws EntityRepoException if the statement cannot be prepared, the id
///         cannot be bound, or stepping the statement reports an error.
boost::optional<Entity> ConcreteEntityRepo::findById(uint64_t id)
{
    // Finalizes the prepared statement on every exit path.
    struct StatementFinalizer {
        void operator()(sqlite3_stmt* s) const { sqlite3_finalize(s); }
    };

    ScopedDbConnection connection = _connectionManager.getConnection();

    sqlite3_stmt* prepared = nullptr;
    const int prepareResult = sqlite3_prepare_v2(
        connection.get(),
        "SELECT id, balance FROM entities WHERE id = ?",
        -1,
        &prepared,
        nullptr);
    if (prepareResult != SQLITE_OK) {
        throw EntityRepoException();
    }
    std::unique_ptr<sqlite3_stmt, StatementFinalizer> statement(prepared);

    if (sqlite3_bind_int64(statement.get(), 1, static_cast<sqlite3_int64>(id)) != SQLITE_OK) {
        throw EntityRepoException();
    }

    switch (sqlite3_step(statement.get())) {
    case SQLITE_ROW: {
        Entity found{
            sqlite3_column_int64(statement.get(), 0),
            sqlite3_column_int64(statement.get(), 1)};
        return found;
    }
    case SQLITE_DONE:
        return boost::none;
    default:
        throw EntityRepoException();
    }
}
/// Demo: saves entity 1 directly through the repository, outside of any
/// transaction, and reports completion on stdout.
void noTransaction(EntityRepo& repo)
{
    const Entity standalone{1, 1000};
    repo.save(standalone);
    std::cout << "[+] No transaction: finished" << std::endl;
}
/// Demo: a single transaction that saves entity 2, reads it back, adds 1000
/// to its balance and saves it again.
void inTransactionNoOneElse(TransactionManager& transactionManager, EntityRepo& entityRepo)
{
    const auto work = [&entityRepo]() {
        entityRepo.save(Entity{2, 1000});

        Entity reloaded = *entityRepo.findById(2);
        reloaded.balance += 1000;
        entityRepo.save(reloaded);
    };

    transactionManager.performInTransaction(work);
    std::cout << "[+] In transaction with no one else" << std::endl;
}
/// Demo: a transaction that saves and updates entity 3 and then requests a
/// rollback by throwing AbortTransaction.
void transactionAborted(TransactionManager& transactionManager, EntityRepo& entityRepo)
{
    const auto work = [&entityRepo]() {
        entityRepo.save(Entity{3, 1000});

        Entity reloaded = *entityRepo.findById(3);
        reloaded.balance += 1000;
        entityRepo.save(reloaded);

        throw AbortTransaction();
    };

    transactionManager.performInTransaction(work);
    std::cout << "[+] Abort transaction example: Finished" << std::endl;
}
/// Demo: two threads, each in its own transaction. The writer saves entity 4
/// and keeps its transaction open until the reader has looked the entity up;
/// the reader reports whether the uncommitted row was visible to it.
void simultaneousTransactions(TransactionManager& transactionManager, EntityRepo& entityRepo)
{
    std::atomic<bool> saveExecuted{false};
    std::atomic<bool> findExecuted{false};

    const auto readerWork = [&]() {
        while (!saveExecuted.load()) {}  // wait until the writer has saved
        const boost::optional<Entity> seen = entityRepo.findById(4);
        findExecuted.store(true);
        if (!seen) {
            std::cout << "[+] Simultaneous non-interfering transactions: success" << std::endl;
        }
        else {
            std::cout << "[-] Transaction behavior violated!" << std::endl;
        }
    };

    const auto writerWork = [&]() {
        entityRepo.save(Entity{4, 1000});
        saveExecuted.store(true);
        while (!findExecuted.load()) {}  // keep the transaction open for the reader
    };

    std::thread reader([&]() { transactionManager.performInTransaction(readerWork); });
    std::thread writer([&]() { transactionManager.performInTransaction(writerWork); });

    reader.join();
    writer.join();
    std::cout << "[+] Simultaneous transaction example: Finished" << std::endl;
}
/// Demo: two threads in separate transactions write the same row. The second
/// thread (t2) saves entity 5 and keeps its transaction open; the first
/// thread (t1) then tries to save the same entity, which is expected to fail
/// with EntityRepoException, after which t1 aborts its own transaction.
///
/// t2 spins on `findExecuted`, so t1 sets that flag on every path that leaves
/// its callback normally or through the expected exception. Otherwise an
/// unexpected outcome would leave t2 spinning forever and join() would never
/// return.
void simultaneousConflictingTransactions(
    TransactionManager& transactionManager,
    EntityRepo& entityRepo)
{
    std::atomic<bool> saveExecuted{false};
    std::atomic<bool> findExecuted{false};
    std::thread t1([&]() {
        transactionManager.performInTransaction([&]() {
            while (!saveExecuted) {}
            boost::optional<Entity> e = entityRepo.findById(5);
            if (e) {
                std::cout << "[-] Transaction behavior violated!" << std::endl;
                findExecuted = true;  // release t2 even though the check failed
                return;
            }
            try {
                entityRepo.save(Entity{5, 2000});
            }
            catch (EntityRepoException&) {
                findExecuted = true;
                std::cout << "[+] Conflicting operations: correct" << std::endl;
                throw AbortTransaction();
            }
            // The conflicting save was expected to be rejected; report it and
            // still release t2 so the demo terminates.
            std::cout << "[-] Transaction behavior violated!" << std::endl;
            findExecuted = true;
        });
    });
    std::thread t2([&]() {
        transactionManager.performInTransaction([&]() {
            Entity e{5, 1000};
            entityRepo.save(e);
            saveExecuted = true;
            while (!findExecuted) {}
        });
    });
    t1.join();
    t2.join();
    std::cout << "[+] Simultaneous transaction example: Finished" << std::endl;
}
/// Demo: like the previous conflicting-writes demo, but the writer's
/// transaction has already finished by the time the reader thread attempts
/// its own save of entity 6. The reader looks the entity up while the writer
/// is still inside its transaction, waits for the writer to finish, and then
/// tries to save; an EntityRepoException from that save is reported as the
/// correct outcome and the reader aborts its transaction.
void simultaneousConflictingTransactions2(
    TransactionManager& transactionManager,
    EntityRepo& entityRepo)
{
    std::atomic<bool> saveExecuted{false};
    std::atomic<bool> findExecuted{false};
    std::atomic<bool> transactionCommitted{false};

    const auto readerWork = [&]() {
        while (!saveExecuted.load()) {}  // wait until the writer has saved
        const boost::optional<Entity> seen = entityRepo.findById(6);
        findExecuted.store(true);
        if (seen) {
            std::cout << "[-] Transaction behavior violated!" << std::endl;
            throw AbortTransaction();
        }

        while (!transactionCommitted.load()) {}  // wait until the writer is done
        try {
            entityRepo.save(Entity{6, 2000});
        }
        catch (EntityRepoException&) {
            std::cout << "[+] Conflicting operations: correct" << std::endl;
            throw AbortTransaction();
        }
    };

    const auto writerWork = [&]() {
        entityRepo.save(Entity{6, 1000});
        saveExecuted.store(true);
        while (!findExecuted.load()) {}  // keep the transaction open for the reader
    };

    std::thread reader([&]() { transactionManager.performInTransaction(readerWork); });
    std::thread writer([&]() {
        transactionManager.performInTransaction(writerWork);
        transactionCommitted.store(true);
    });

    reader.join();
    writer.join();
    std::cout << "[+] Simultaneous transaction example: Finished" << std::endl;
}
/// Demo: an outer transaction saves entity 7, then a nested
/// performInTransaction call reads it back, adds 1000 to its balance and
/// saves it again.
void nestedTransaction(TransactionManager& transactionManager, EntityRepo& entityRepo)
{
    const auto innerWork = [&entityRepo]() {
        boost::optional<Entity> stored = entityRepo.findById(7);
        assert(stored);
        stored->balance += 1000;
        entityRepo.save(*stored);
    };

    const auto outerWork = [&]() {
        entityRepo.save(Entity{7, 1000});
        transactionManager.performInTransaction(innerWork);
    };

    transactionManager.performInTransaction(outerWork);
    std::cout << "[+] Nested transaction: Finished" << std::endl;
}
/// Demo: an outer transaction saves entity 8, a nested call updates it, and
/// then the outer level throws AbortTransaction after the nested call has
/// returned.
void nestedTransactionOuterAbort(TransactionManager& transactionManager, EntityRepo& entityRepo)
{
    const auto innerWork = [&entityRepo]() {
        boost::optional<Entity> stored = entityRepo.findById(8);
        assert(stored);
        stored->balance += 1000;
        entityRepo.save(*stored);
    };

    const auto outerWork = [&]() {
        entityRepo.save(Entity{8, 1000});
        transactionManager.performInTransaction(innerWork);
        throw AbortTransaction();
    };

    transactionManager.performInTransaction(outerWork);
    std::cout << "[+] Nested transaction, outer aborts: Finished" << std::endl;
}
/// Demo: an outer transaction saves entity 9 and a nested call updates it
/// and then throws AbortTransaction from the inner level.
void nestedTransactionInnerAbort(TransactionManager& transactionManager, EntityRepo& entityRepo)
{
    const auto innerWork = [&entityRepo]() {
        boost::optional<Entity> stored = entityRepo.findById(9);
        assert(stored);
        stored->balance += 1000;
        entityRepo.save(*stored);
        throw AbortTransaction();
    };

    const auto outerWork = [&]() {
        entityRepo.save(Entity{9, 1000});
        transactionManager.performInTransaction(innerWork);
    };

    transactionManager.performInTransaction(outerWork);
    std::cout << "[+] Nested transaction, inner aborts: Finished" << std::endl;
}
int main(int argc, char** argv)
{
const std::string dbFile = "poc.db";
ConcreteConnectionManager connectionManager(dbFile);
ConcreteTransactionManager transactionManager(connectionManager);
ConcreteEntityRepo entityRepo(transactionManager);
{
ScopedDbConnection conn = connectionManager.getConnection();
const char* dropTableSql = "DROP TABLE IF EXISTS entities";
sqlite3_exec(conn.get(), dropTableSql, nullptr, nullptr, nullptr);
const char* createTableSql = "CREATE TABLE IF NOT EXISTS entities "
"(id INT PRIMARY KEY, balance INT)";
sqlite3_exec(conn.get(), createTableSql, nullptr, nullptr, nullptr);
const char* walModePragma = "PRAGMA journal_mode=WAL";
sqlite3_exec(conn.get(), walModePragma, nullptr, nullptr, nullptr);
std::cout << "[+] DB Created" << std::endl;
}
// No transaction
noTransaction(entityRepo);
inTransactionNoOneElse(transactionManager, entityRepo);
transactionAborted(transactionManager, entityRepo);
simultaneousTransactions(transactionManager, entityRepo);
simultaneousConflictingTransactions(transactionManager, entityRepo);
simultaneousConflictingTransactions2(transactionManager, entityRepo);
nestedTransaction(transactionManager, entityRepo);
nestedTransactionOuterAbort(transactionManager, entityRepo);
nestedTransactionInnerAbort(transactionManager, entityRepo);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment