// Proof of concept of how to design the architecture to handle transactions in C++.
#include <atomic>
#include <cassert>
#include <cstdint>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <thread>

#include <boost/optional.hpp>
#include <sqlite3.h>
// Interfaces | |
using ScopedDbConnection = std::unique_ptr<sqlite3, std::function<void(sqlite3*)>>; | |
class DbConnectionManager { | |
public: | |
virtual ScopedDbConnection getConnection() = 0; | |
}; | |
/// Throw from inside a performInTransaction callback to request a rollback.
/// It does not derive from std::exception, so a generic
/// `catch (const std::exception&)` in user code will not intercept it.
class AbortTransaction {};

/// Thrown when a transaction could not be committed and was rolled back instead.
class TransactionAborted {};

/// Runs units of work inside a database transaction.
class TransactionManager {
public:
    // Polymorphic base: a virtual destructor makes deleting an implementation
    // through a TransactionManager pointer well-defined.
    virtual ~TransactionManager() = default;

    /// Executes f inside a transaction. How nesting, AbortTransaction and other
    /// exceptions are handled is defined by the implementation.
    virtual void performInTransaction(const std::function<void()>& f) = 0;
};
struct Entity; | |
class EntityRepoException {}; | |
class EntityRepo { | |
public: | |
virtual void save(const Entity& entity) = 0; | |
virtual boost::optional<Entity> findById(uint64_t id) = 0; | |
}; | |
// Implementations | |
class ConcreteConnectionManager : public DbConnectionManager { | |
public: | |
ConcreteConnectionManager(const std::string& dbFile); | |
ScopedDbConnection getConnection() override; | |
private: | |
std::string _dbFile; | |
}; | |
ConcreteConnectionManager::ConcreteConnectionManager(const std::string& dbFile) : | |
_dbFile(dbFile) | |
{} | |
ScopedDbConnection ConcreteConnectionManager::getConnection() | |
{ | |
sqlite3* connection; | |
int openResult = sqlite3_open_v2( | |
_dbFile.c_str(), | |
&connection, | |
SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX, | |
NULL | |
); | |
if (openResult != SQLITE_OK) { | |
std::cerr << "[-] Could not open database: " << openResult << std::endl; | |
sqlite3_close_v2(connection); | |
throw std::runtime_error("Database could not be opened"); | |
} | |
return ScopedDbConnection(connection, [](sqlite3* c) { sqlite3_close_v2(c); }); | |
} | |
class ConcreteTransactionManager : public DbConnectionManager, public TransactionManager { | |
public: | |
ConcreteTransactionManager(ConcreteConnectionManager& connectionManager); | |
ScopedDbConnection getConnection() override; | |
void performInTransaction(const std::function<void()>& f) override; | |
private: | |
struct TransactionInfo { | |
ScopedDbConnection dbConnection; | |
uint32_t count; | |
}; | |
TransactionInfo& setupTransaction(const std::thread::id& threadId); | |
void abortTransaction(TransactionInfo& transactionInfo); | |
void commitTransaction(TransactionInfo& transactionInfo); | |
ConcreteConnectionManager& _connectionManager; | |
std::map<std::thread::id, TransactionInfo> _currentTransactions; | |
std::mutex _currentTransactionsMutex; | |
}; | |
ConcreteTransactionManager::ConcreteTransactionManager( | |
ConcreteConnectionManager& connectionManager) : | |
_connectionManager(connectionManager) | |
{} | |
/// Returns the calling thread's transaction connection if it has one open,
/// otherwise a fresh owning connection from the wrapped connection manager.
ScopedDbConnection ConcreteTransactionManager::getConnection()
{
    const std::thread::id caller = std::this_thread::get_id();
    std::lock_guard<std::mutex> lock(_currentTransactionsMutex);
    const auto active = _currentTransactions.find(caller);
    if (active != _currentTransactions.end()) {
        // Lend out the transaction's connection; the no-op deleter leaves
        // ownership with the TransactionInfo entry.
        sqlite3* const borrowed = active->second.dbConnection.get();
        return ScopedDbConnection(borrowed, [](sqlite3*) {});
    }
    // No transaction on this thread: hand out an independent connection.
    return _connectionManager.getConnection();
}
/// Rolls back the calling thread's transaction and forgets it.
///
/// transactionInfo must be the calling thread's entry in _currentTransactions;
/// it is destroyed here and must not be used once this function returns.
void ConcreteTransactionManager::abortTransaction(TransactionInfo& transactionInfo)
{
    sqlite3* connection = transactionInfo.dbConnection.get();
    // After certain errors (SQLITE_BUSY, SQLITE_FULL, SQLITE_IOERR, ...),
    // SQLite may already have rolled the transaction back on its own, for
    // example when COMMIT fails. A ROLLBACK with no active transaction is
    // itself an error, so only issue it while a transaction is still open.
    if (sqlite3_get_autocommit(connection) == 0) {
        int abortResult = sqlite3_exec(
            connection,
            "ROLLBACK TRANSACTION",
            nullptr,
            nullptr,
            nullptr
        );
        if (abortResult != SQLITE_OK) {
            // Report even in NDEBUG builds, where the assert below is compiled out.
            std::cerr << "[-] Could not roll back transaction: "
                      << sqlite3_errmsg(connection) << std::endl;
        }
        // If a transaction cannot be aborted, something bad happened.
        assert(abortResult == SQLITE_OK);
    }
    // Erasing the entry destroys transactionInfo and releases its connection
    // through the ScopedDbConnection deleter.
    auto threadId = std::this_thread::get_id();
    std::lock_guard<std::mutex> _(_currentTransactionsMutex);
    _currentTransactions.erase(threadId);
}
/// Commits the calling thread's transaction and forgets it.
///
/// If COMMIT fails, the transaction is aborted via abortTransaction and
/// TransactionAborted is thrown. Either way, transactionInfo (the calling
/// thread's entry) is destroyed.
void ConcreteTransactionManager::commitTransaction(TransactionInfo& transactionInfo)
{
    const int status = sqlite3_exec(
        transactionInfo.dbConnection.get(), "COMMIT TRANSACTION", nullptr, nullptr, nullptr);
    if (status != SQLITE_OK) {
        // The commit did not go through: discard the work and tell the caller.
        abortTransaction(transactionInfo);
        throw TransactionAborted();
    }
    // Committed: drop this thread's entry, which also releases its connection.
    const std::thread::id owner = std::this_thread::get_id();
    std::lock_guard<std::mutex> guard(_currentTransactionsMutex);
    _currentTransactions.erase(owner);
}
/// Joins threadId's open transaction (bumping its nesting count) or, if it has
/// none, opens a new connection, issues BEGIN TRANSACTION and registers it with
/// a count of 1.
///
/// @param threadId id of the calling thread.
/// @return reference to the entry in _currentTransactions. std::map nodes are
///         stable, so it stays valid until that entry is erased.
/// @throws std::runtime_error if BEGIN TRANSACTION fails; exceptions from
///         opening the connection propagate unchanged.
ConcreteTransactionManager::TransactionInfo&
ConcreteTransactionManager::setupTransaction(const std::thread::id& threadId)
{
    {
        std::lock_guard<std::mutex> _(_currentTransactionsMutex);
        auto existing = _currentTransactions.find(threadId);
        if (existing != _currentTransactions.end()) {
            // Nested call on this thread: join the transaction already open.
            ++existing->second.count;
            return existing->second;
        }
    }
    // Opening the connection and running BEGIN happen outside the lock, so
    // other threads are not blocked behind file I/O. Entries are only ever
    // added or removed by the thread whose id keys them, so nothing can insert
    // threadId between the lookup above and the emplace below.
    TransactionInfo transactionInfo{_connectionManager.getConnection(), 1};
    int transactionStartResult = sqlite3_exec(
        transactionInfo.dbConnection.get(),
        "BEGIN TRANSACTION",
        nullptr,
        nullptr,
        nullptr
    );
    if (transactionStartResult != SQLITE_OK) {
        // transactionInfo goes out of scope here and releases the connection.
        throw std::runtime_error("Could not start transaction");
    }
    std::lock_guard<std::mutex> _(_currentTransactionsMutex);
    auto inserted = _currentTransactions.emplace(threadId, std::move(transactionInfo)).first;
    return inserted->second;
}
void ConcreteTransactionManager::performInTransaction(const std::function<void()>& f) | |
{ | |
auto threadId = std::this_thread::get_id(); | |
TransactionInfo& transactionInfo = setupTransaction(threadId); | |
try { | |
f(); | |
} | |
catch (AbortTransaction&) { | |
if (--transactionInfo.count > 0) { | |
throw; | |
} | |
abortTransaction(transactionInfo); | |
return; | |
} | |
catch (...) { | |
if (--transactionInfo.count > 0) { | |
throw; | |
} | |
abortTransaction(transactionInfo); | |
throw; | |
} | |
if (--transactionInfo.count > 0) { | |
return; | |
} | |
commitTransaction(transactionInfo); | |
} | |
// Plain record stored in the `entities` table. Field order is part of the
// interface: callers aggregate-initialize it as Entity{id, balance}.
struct Entity
{
    int64_t id;
    int64_t balance;
};
class ConcreteEntityRepo : public EntityRepo { | |
public: | |
ConcreteEntityRepo(DbConnectionManager& connectionManager); | |
void save(const Entity& entity) override; | |
boost::optional<Entity> findById(uint64_t id) override; | |
private: | |
DbConnectionManager& _connectionManager; | |
}; | |
ConcreteEntityRepo::ConcreteEntityRepo(DbConnectionManager& connectionManager) : | |
_connectionManager(connectionManager) | |
{} | |
void ConcreteEntityRepo::save(const Entity& entity) | |
{ | |
const char* sql = "INSERT OR REPLACE INTO entities (id, balance) VALUES (?, ?)"; | |
ScopedDbConnection dbConnection = _connectionManager.getConnection(); | |
sqlite3_stmt* rawStmt; | |
int stmtCreationResult = sqlite3_prepare_v2( | |
dbConnection.get(), | |
sql, | |
-1, | |
&rawStmt, | |
nullptr | |
); | |
if (stmtCreationResult != SQLITE_OK) { | |
throw EntityRepoException(); | |
} | |
auto stmtDeleter = [](sqlite3_stmt* p) { sqlite3_finalize(p); }; | |
std::unique_ptr<sqlite3_stmt, decltype(stmtDeleter)> stmt(rawStmt, stmtDeleter); | |
if (sqlite3_bind_int64(stmt.get(), 1, entity.id) != SQLITE_OK) { | |
throw EntityRepoException(); | |
} | |
if (sqlite3_bind_int64(stmt.get(), 2, entity.balance) != SQLITE_OK) { | |
throw EntityRepoException(); | |
} | |
if (sqlite3_step(stmt.get()) != SQLITE_DONE) { | |
throw EntityRepoException(); | |
} | |
} | |
/// Looks up the entity with the given id.
///
/// id is bound as SQLite's signed 64-bit integer; values above INT64_MAX do
/// not round-trip (the unsigned-to-signed conversion is implementation-defined
/// before C++20).
/// @return the first matching row, or boost::none when no row matches.
/// @throws EntityRepoException on any prepare/bind/step error.
boost::optional<Entity> ConcreteEntityRepo::findById(uint64_t id)
{
    ScopedDbConnection connection = _connectionManager.getConnection();
    sqlite3_stmt* prepared = nullptr;
    if (sqlite3_prepare_v2(connection.get(),
                           "SELECT id, balance FROM entities WHERE id = ?",
                           -1,
                           &prepared,
                           nullptr) != SQLITE_OK) {
        throw EntityRepoException();
    }
    auto finalizer = [](sqlite3_stmt* s) { sqlite3_finalize(s); };
    std::unique_ptr<sqlite3_stmt, decltype(finalizer)> query(prepared, finalizer);

    if (sqlite3_bind_int64(query.get(), 1, id) != SQLITE_OK) {
        throw EntityRepoException();
    }
    switch (sqlite3_step(query.get())) {
    case SQLITE_ROW: {
        Entity found{sqlite3_column_int64(query.get(), 0), sqlite3_column_int64(query.get(), 1)};
        return found;
    }
    case SQLITE_DONE:
        return boost::none;
    default:
        throw EntityRepoException();
    }
}
/// Demo: a single save issued without performInTransaction.
void noTransaction(EntityRepo& repo)
{
    repo.save(Entity{1, 1000});
    std::cout << "[+] No transaction: finished" << std::endl;
}
/// Demo: save, read back and update id 2 inside one uncontended transaction.
void inTransactionNoOneElse(TransactionManager& transactionManager, EntityRepo& entityRepo)
{
    auto saveReadUpdate = [&entityRepo]() {
        entityRepo.save(Entity{2, 1000});
        // The row was just written in this transaction, so it is expected to exist.
        Entity reloaded = *entityRepo.findById(2);
        reloaded.balance += 1000;
        entityRepo.save(reloaded);
    };
    transactionManager.performInTransaction(saveReadUpdate);
    std::cout << "[+] In transaction with no one else" << std::endl;
}
/// Demo: write and update id 3, then throw AbortTransaction so the transaction
/// manager rolls the work back.
void transactionAborted(TransactionManager& transactionManager, EntityRepo& entityRepo)
{
    auto writeThenAbort = [&entityRepo]() {
        entityRepo.save(Entity{3, 1000});
        Entity reloaded = *entityRepo.findById(3);
        reloaded.balance += 1000;
        entityRepo.save(reloaded);
        throw AbortTransaction();
    };
    transactionManager.performInTransaction(writeThenAbort);
    std::cout << "[+] Abort transaction example: Finished" << std::endl;
}
/// Demo: two overlapping transactions on different threads.
///
/// t2 saves id 4 and keeps its transaction open until t1 has looked the row up.
/// t1 must not see the uncommitted row.
void simultaneousTransactions(TransactionManager& transactionManager, EntityRepo& entityRepo)
{
    std::atomic<bool> saveExecuted{false};
    std::atomic<bool> findExecuted{false};
    std::thread t1([&]() {
        transactionManager.performInTransaction([&]() {
            // Yield while spinning so the wait does not monopolise a CPU core.
            while (!saveExecuted) {
                std::this_thread::yield();
            }
            boost::optional<Entity> e = entityRepo.findById(4);
            findExecuted = true;
            if (e) {
                std::cout << "[-] Transaction behavior violated!" << std::endl;
            }
            else {
                std::cout << "[+] Simultaneous non-interfering transactions: success" << std::endl;
            }
        });
    });
    std::thread t2([&]() {
        transactionManager.performInTransaction([&]() {
            Entity e{4, 1000};
            entityRepo.save(e);
            saveExecuted = true;
            while (!findExecuted) {
                std::this_thread::yield();
            }
        });
    });
    t1.join();
    t2.join();
    std::cout << "[+] Simultaneous transaction example: Finished" << std::endl;
}
/// Demo: t2 writes id 5 and holds its transaction open; t1 then tries to write
/// the same row and is expected to have the write rejected (EntityRepoException).
void simultaneousConflictingTransactions(
    TransactionManager& transactionManager,
    EntityRepo& entityRepo)
{
    std::atomic<bool> saveExecuted{false};
    std::atomic<bool> findExecuted{false};
    std::thread t1([&]() {
        transactionManager.performInTransaction([&]() {
            while (!saveExecuted) {
                std::this_thread::yield();
            }
            boost::optional<Entity> e = entityRepo.findById(5);
            bool writeRejected = false;
            if (e) {
                std::cout << "[-] Transaction behavior violated!" << std::endl;
            }
            else {
                try {
                    entityRepo.save(Entity{5, 2000});
                }
                catch (EntityRepoException&) {
                    writeRejected = true;
                }
            }
            // Release t2 on every path. t2 keeps its transaction open until this
            // flag is set, so leaving it unset (as happened when the row was
            // visible or the save succeeded) made t2 spin forever.
            findExecuted = true;
            if (writeRejected) {
                std::cout << "[+] Conflicting operations: correct" << std::endl;
                throw AbortTransaction();
            }
            if (!e) {
                // The conflicting write was accepted.
                std::cout << "[-] Transaction behavior violated!" << std::endl;
            }
        });
    });
    std::thread t2([&]() {
        transactionManager.performInTransaction([&]() {
            Entity e{5, 1000};
            entityRepo.save(e);
            saveExecuted = true;
            while (!findExecuted) {
                std::this_thread::yield();
            }
        });
    });
    t1.join();
    t2.join();
    std::cout << "[+] Simultaneous transaction example: Finished" << std::endl;
}
/// Demo: t1 reads id 6 while t2's write is uncommitted, waits for t2 to commit,
/// then writes the row from its now-stale view and is expected to be rejected.
void simultaneousConflictingTransactions2(
    TransactionManager& transactionManager,
    EntityRepo& entityRepo)
{
    std::atomic<bool> saveExecuted{false};
    std::atomic<bool> findExecuted{false};
    std::atomic<bool> transactionCommitted{false};
    std::thread t1([&]() {
        transactionManager.performInTransaction([&]() {
            while (!saveExecuted) {
                std::this_thread::yield();
            }
            boost::optional<Entity> e = entityRepo.findById(6);
            findExecuted = true;
            if (e) {
                std::cout << "[-] Transaction behavior violated!" << std::endl;
                throw AbortTransaction();
            }
            while (!transactionCommitted) {
                std::this_thread::yield();
            }
            try {
                entityRepo.save(Entity{6, 2000});
            }
            catch (EntityRepoException&) {
                std::cout << "[+] Conflicting operations: correct" << std::endl;
                throw AbortTransaction();
            }
            // Reaching this point means the write based on a stale read was
            // accepted, which previously went unreported.
            std::cout << "[-] Transaction behavior violated!" << std::endl;
        });
    });
    std::thread t2([&]() {
        transactionManager.performInTransaction([&]() {
            Entity e{6, 1000};
            entityRepo.save(e);
            saveExecuted = true;
            while (!findExecuted) {
                std::this_thread::yield();
            }
        });
        transactionCommitted = true;
    });
    t1.join();
    t2.join();
    std::cout << "[+] Simultaneous transaction example: Finished" << std::endl;
}
/// Demo: an inner performInTransaction that updates the row the outer one saved.
void nestedTransaction(TransactionManager& transactionManager, EntityRepo& entityRepo)
{
    auto topUp = [&entityRepo]() {
        boost::optional<Entity> stored = entityRepo.findById(7);
        assert(stored);
        stored->balance += 1000;
        entityRepo.save(*stored);
    };
    transactionManager.performInTransaction([&]() {
        entityRepo.save(Entity{7, 1000});
        transactionManager.performInTransaction(topUp);
    });
    std::cout << "[+] Nested transaction: Finished" << std::endl;
}
/// Demo: the inner scope completes normally, then the outer scope throws
/// AbortTransaction.
void nestedTransactionOuterAbort(TransactionManager& transactionManager, EntityRepo& entityRepo)
{
    auto topUp = [&entityRepo]() {
        boost::optional<Entity> stored = entityRepo.findById(8);
        assert(stored);
        stored->balance += 1000;
        entityRepo.save(*stored);
    };
    transactionManager.performInTransaction([&]() {
        entityRepo.save(Entity{8, 1000});
        transactionManager.performInTransaction(topUp);
        throw AbortTransaction();
    });
    std::cout << "[+] Nested transaction, outer aborts: Finished" << std::endl;
}
/// Demo: the inner scope throws AbortTransaction after updating the row; the
/// outer scope does not catch it.
void nestedTransactionInnerAbort(TransactionManager& transactionManager, EntityRepo& entityRepo)
{
    auto topUpThenAbort = [&entityRepo]() {
        boost::optional<Entity> stored = entityRepo.findById(9);
        assert(stored);
        stored->balance += 1000;
        entityRepo.save(*stored);
        throw AbortTransaction();
    };
    transactionManager.performInTransaction([&]() {
        entityRepo.save(Entity{9, 1000});
        transactionManager.performInTransaction(topUpThenAbort);
    });
    std::cout << "[+] Nested transaction, inner aborts: Finished" << std::endl;
}
int main(int argc, char** argv) | |
{ | |
const std::string dbFile = "poc.db"; | |
ConcreteConnectionManager connectionManager(dbFile); | |
ConcreteTransactionManager transactionManager(connectionManager); | |
ConcreteEntityRepo entityRepo(transactionManager); | |
{ | |
ScopedDbConnection conn = connectionManager.getConnection(); | |
const char* dropTableSql = "DROP TABLE IF EXISTS entities"; | |
sqlite3_exec(conn.get(), dropTableSql, nullptr, nullptr, nullptr); | |
const char* createTableSql = "CREATE TABLE IF NOT EXISTS entities " | |
"(id INT PRIMARY KEY, balance INT)"; | |
sqlite3_exec(conn.get(), createTableSql, nullptr, nullptr, nullptr); | |
const char* walModePragma = "PRAGMA journal_mode=WAL"; | |
sqlite3_exec(conn.get(), walModePragma, nullptr, nullptr, nullptr); | |
std::cout << "[+] DB Created" << std::endl; | |
} | |
// No transaction | |
noTransaction(entityRepo); | |
inTransactionNoOneElse(transactionManager, entityRepo); | |
transactionAborted(transactionManager, entityRepo); | |
simultaneousTransactions(transactionManager, entityRepo); | |
simultaneousConflictingTransactions(transactionManager, entityRepo); | |
simultaneousConflictingTransactions2(transactionManager, entityRepo); | |
nestedTransaction(transactionManager, entityRepo); | |
nestedTransactionOuterAbort(transactionManager, entityRepo); | |
nestedTransactionInnerAbort(transactionManager, entityRepo); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment