Skip to content

Instantly share code, notes, and snippets.

@lemon24
Last active October 22, 2023 09:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lemon24/64704ced769c5723a75ad64b5d023883 to your computer and use it in GitHub Desktop.
Save lemon24/64704ced769c5723a75ad64b5d023883 to your computer and use it in GitHub Desktop.
Distributed key-value store prototype, with no kind of consistency.
"""
Distributed key-value store prototype, with no kind of consistency.
---
A demo (from before we had bootstrapping): https://asciinema.org/a/616231
On one machine:
>>> d = dkv.DKV()
>>> d.start()
>>>
>>> d.get('one', timeout=1) # wait 1s for a response from others
>>> d.set('one', b'111')
>>> d.get('one')
b'111'
On a second machine:
>>> d = dkv.DKV()
>>> d.start()
>>>
>>> d.get('one') # retrieve value from the first machine
b'111'
>>> d.set('two', b'222')
Back on the first machine:
>>> d.get('two', timeout=0) # value already received from the second machine
b'222'
---
The current network architecture is as follows:
* each node has a SUB socket to receive messages from all other nodes
* each node has a PUB socket to send messages to all other nodes
       .--------------------------.
       |                          v
    .-----------.        .-----------.
    | PUB | SUB |<-------| PUB | SUB |
    '-----------'        '-----------'
       |     ^   .----------'     ^
       |     '---------.          |
       |         v     |          |
       |      .-----------.       |
       '----->| SUB | PUB |-------'
              '-----------'
This is wildly inefficient; for example, when A asks for a key,
all its peers respond to all their peers, not only to A.
A has a subscription filter for messages intended for itself,
so this doesn't need to be handled in code,
but the network traffic still happens underneath.
Note the ZeroMQ book already has a [shared key-value store] example,
but I wanted to see if I can cobble together something on my own.
---
Discovery works via IPv4 local network UDP broadcast,
and works even if the nodes move networks / change IPs.
This is needed because while ZeroMQ (known) peers can come and go,
there's no way of discovering them.
This is based on the same idea as [zbeacon],
but cobbled together independently from StackOverflow examples
(zbeacon comes from the C binding, and does not exist in the Python one).
Also see the section on [discovery] in the ZeroMQ book.
---
Bootstrapping starts whenever a node gets network connectivity back.
The node requests a list of keys from all other peers,
then requests keys one by one with a small delay.
---
[zbeacon]: http://api.zeromq.org/czmq1-4:zbeacon
[discovery]: https://zguide.zeromq.org/docs/chapter8/#Discovery
[shared key-value store]: https://zguide.zeromq.org/docs/chapter5/#Reliable-Pub-Sub-Clone-Pattern
"""
import threading
import random
import socket
import queue
import errno
import time
import zmq
import sys
from dataclasses import dataclass, field
from functools import wraps, partial
# errno values that Beacon treats as "no network connectivity right now";
# an OSError carrying any other errno is re-raised there.
ERRNO_NET = {errno.ENETUNREACH, errno.EADDRNOTAVAIL, errno.ENETDOWN}
class Beacon:
    """IPv4 local network broadcast beacon.

    Broadcast ``prefix + payload`` over UDP to port ``port`` every
    ``interval`` seconds, and listen on the same port for broadcasts
    from other nodes.

    ``event`` is called with one tuple per occurrence:

    * ``('net_up', address)`` when the local address could be determined
      after having been unknown;
    * ``('net_down', exc)`` when sending fails with an errno in ERRNO_NET
      while a local address was known;
    * ``('ping', address, payload)`` for each prefixed datagram received
      that is not our own (same address and same payload).

    ``event`` is called from the sender and receiver threads.
    """

    interval = 1
    prefix = b'beacon '
    port = 5005

    def __init__(self, payload, event):
        """Store the payload (bytes) and the event callback; start nothing."""
        self.payload = payload
        self.event = event
        # our own address as last determined by the receiver thread;
        # None while we believe the network is down
        self.address = None
        self.sender = threading.Thread(target=self._send, daemon=True)
        self.receiver = threading.Thread(target=self._receive, daemon=True)
        self.done = False

    def start(self):
        """Start the sender and receiver threads."""
        self.sender.start()
        self.receiver.start()

    def shutdown(self):
        """Ask both threads to stop and wait for them to finish."""
        self.done = True
        self.sender.join()
        self.receiver.join()

    # https://stackoverflow.com/a/64067297 + comment by Tamir Adler
    def _send(self):
        """Sender thread body: broadcast the payload until shutdown."""
        message = self.prefix + self.payload
        # the context manager closes the socket on shutdown and on errors
        with socket.socket(
            socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP
        ) as s:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
            while not self.done:
                try:
                    s.sendto(message, ('255.255.255.255', self.port))
                except OSError as e:
                    if e.errno not in ERRNO_NET:
                        raise
                    # report the transition only once; the receiver thread
                    # re-detects the address and reports net_up later
                    if self.address:
                        self.event(('net_down', e))
                        self.address = None
                time.sleep(self.interval)

    def _receive(self):
        """Receiver thread body: turn incoming broadcasts into events."""
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
            # only needed when running more than one per host
            # https://gist.github.com/Crtrpt/616eaae1ec00810c1d04474f188bcebd
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
            s.bind(('0.0.0.0', self.port))
            # wake up regularly so self.done and self.address get re-checked
            s.settimeout(self.interval)

            while not self.done:
                if not self.address:
                    try:
                        address = get_local_address()
                    except OSError as e:
                        if e.errno not in ERRNO_NET:
                            raise
                    else:
                        self.address = address
                        self.event(('net_up', address))

                try:
                    data, (address, _) = s.recvfrom(1024)
                except socket.timeout:
                    # socket.timeout is TimeoutError itself on Python 3.10+,
                    # but a distinct OSError subclass before that
                    continue

                if not self.address:
                    continue
                if not data.startswith(self.prefix):
                    continue
                payload = data.removeprefix(self.prefix)
                # skip our own broadcast
                if address == self.address and payload == self.payload:
                    continue

                self.event(('ping', address, payload))
def get_local_address():
    """Return the local IPv4 address of a UDP socket connected to 8.8.8.8:80.

    Any OSError raised by connect() propagates to the caller.
    """
    # https://stackoverflow.com/a/166589
    probe = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        probe.connect(('8.8.8.8', 80))
        host = probe.getsockname()[0]
        return host
    finally:
        probe.close()
class Tracker:
    """IPv4 local network presence tracker.

    Broadcast own presence using a Beacon.

    A peer is the tuple that follows 'ping' in a beacon event,
    that is, ``(address, payload)``.

    Call event(('peer_up', address, payload)) on the first ping from a peer.
    Call event(('peer_down', address, payload)) after more than ``timeout``
    seconds without a ping from it, and for every known peer on shutdown.
    Any other beacon event (e.g. net_up / net_down) is passed through to
    event() unchanged.

    event() is called from the tracker's worker thread.
    """

    timeout = 4

    def __init__(self, payload, event):
        """Create the beacon and worker thread without starting them."""
        self.event = event
        # peer tuple -> time.monotonic() value of its last ping
        self.peers = {}
        # beacon threads put events here; the worker thread consumes them
        self.queue = queue.Queue()
        self.beacon = Beacon(payload, self.queue.put)
        self.worker = threading.Thread(target=self._worker, daemon=True)
        self.done = False

    def start(self):
        """Start the beacon and the worker thread."""
        self.beacon.start()
        self.worker.start()

    def shutdown(self):
        """Stop the beacon and the worker thread, waiting for both."""
        self.done = True
        self.beacon.shutdown()
        self.worker.join()

    def _worker(self):
        """Worker thread body: expire silent peers and process beacon events."""
        while not self.done:
            now = time.monotonic()
            # iterate over a copy, since expired peers are deleted on the way
            for peer, last_seen in list(self.peers.items()):
                if now - last_seen > self.timeout:
                    self.event(('peer_down', *peer))
                    del self.peers[peer]

            try:
                # the timeout bounds how long expiry checks can be delayed
                event = self.queue.get(timeout=self.beacon.interval)
            except queue.Empty:
                continue

            if event[0] != 'ping':
                self.event(event)
                continue

            peer = event[1:]
            if peer not in self.peers:
                self.event(('peer_up', *peer))
            # take a fresh timestamp: `now` was sampled before blocking on
            # the queue and may be up to one beacon interval old
            self.peers[peer] = time.monotonic()

        for peer in self.peers:
            self.event(('peer_down', *peer))
class DKV:
    """One node of the distributed key-value store.

    Keys are str, values are bytes. Every node keeps its own copy of the
    data in ``self.data``; set() publishes the new value to all peers,
    and get() asks the peers when the key is not known locally.

    Messages between nodes are ZeroMQ multipart messages whose first frame
    is the topic: ``set``, ``question <id>``, ``answer <id>``, ``list <id>``
    or ``list_answer <id>``, where ``<id>`` is the 8-byte id of the node
    that asked. Each topic ``x`` is handled by the ``_handle_x`` method,
    as is each tracker event named ``x``.

    ``log``, if given, is called like print() with diagnostic messages.
    """

    def __init__(self, log=None):
        """Create the sockets, tracker and worker thread; start nothing."""
        self.log = log or (lambda *_: None)
        self.data = {}
        self.worker = threading.Thread(target=self._worker, daemon=True)
        self.done = False

        # https://pyzmq.readthedocs.io/en/latest/howto/morethanbindings.html#thread-safety
        # after __init__, we either lock around socket method calls,
        # or only use a socket from a single thread
        self.ctx = ctx = zmq.Context()

        # socket for sending messages to other peers;
        # used from anywhere, so it needs a lock
        self.pub = ctx.socket(zmq.PUB)
        self.port = self.pub.bind_to_random_port('tcp://*')
        self.pub_lock = threading.Lock()

        # broadcast pub's port to other nodes;
        # send presence changes from others to an internal socket
        sender = ctx.socket(zmq.PAIR)
        sender.bind("inproc://events")
        self.tracker = Tracker(str(self.port).encode(), sender.send_pyobj)
        # ...so we can receive them in the worker thread
        self.tracker_events = ctx.socket(zmq.PAIR)
        self.tracker_events.connect("inproc://events")

        # socket for receiving messages from other peers;
        # only used from the worker thread
        self.sub = ctx.socket(zmq.SUB)
        # we care about value updates and questions from anyone
        self.sub.subscribe(b'set')
        self.sub.subscribe(b'question')
        # but only about answers intended for us
        self.id = random.randbytes(8)
        self.sub.subscribe(b'answer ' + self.id)

        # one threading.Event for each key we've asked about,
        # so get() can wait for an answer
        self.pending_questions = {}

        self.bootstrap = None
        # subscriptions are prefix matches: the trailing space keeps this
        # one from also matching 'list_answer <id>' meant for other nodes
        self.sub.subscribe(b'list ')
        self.sub.subscribe(b'list_answer ' + self.id)

    def start(self):
        """Start the presence tracker and the worker thread."""
        self.tracker.start()
        self.worker.start()

    def shutdown(self):
        """Stop the tracker and the worker thread, then destroy the context."""
        self.done = True
        self.tracker.shutdown()
        self.worker.join()
        self.ctx.destroy()

    def _worker(self):
        """Worker thread body: dispatch incoming messages, drive bootstrap."""
        self.log(f"publishing on tcp://*:{self.port} with id {self.id.hex()}")

        poller = zmq.Poller()
        poller.register(self.sub, zmq.POLLIN)
        poller.register(self.tracker_events, zmq.POLLIN)

        while not self.done:
            # the poll timeout (in ms) bounds how long it takes to notice
            # self.done and how often the bootstrap state is stepped
            for sock, _ in poller.poll(100):
                if sock is self.tracker_events:
                    name, *args = sock.recv_pyobj()
                    self.log("tracker:", name, *args)
                elif sock is self.sub:
                    name, *args = sock.recv_multipart()
                    self.log("sub received:", name, *args)
                    # topic is b'<name>' or b'<name> <id>'
                    name, _, peer_id = name.partition(b' ')
                    name = name.decode()
                    if peer_id:
                        args = (peer_id, *args)
                else:
                    # only the two sockets above are registered
                    raise AssertionError("poller returned an unknown socket")

                try:
                    meth = getattr(self, f'_handle_{name}')
                except AttributeError:
                    self.log("UNHANDLED!", name)
                else:
                    meth(*args)

            if self.bootstrap:
                if self.bootstrap.done:
                    self.bootstrap = None
                else:
                    self.bootstrap.step()

    def _handle_peer_up(self, ip, port_bytes):
        """Tracker event: start receiving messages from a new peer."""
        address = f"tcp://{ip}:{port_bytes.decode()}"
        self.sub.connect(address)

    def _handle_peer_down(self, ip, port_bytes):
        """Tracker event: stop receiving messages from a peer that is gone."""
        address = f"tcp://{ip}:{port_bytes.decode()}"
        self.sub.disconnect(address)

    def _handle_set(self, key, value):
        """A peer set a value; store it locally."""
        self.data[key.decode()] = value

    def _request_question(self, key):
        """Ask all peers for the value of key (str)."""
        self._pub_send_multipart((b'question ' + self.id, key.encode()))

    def _handle_question(self, peer_id, key):
        """A peer asked for key; answer only if we have a value for it."""
        value = self.data.get(key.decode())
        # compare against None like get() does, so that an empty value
        # (b'') still gets answered
        if value is None:
            return
        self._pub_send_multipart((b'answer ' + peer_id, key, value))

    def _handle_answer(self, peer_id, key, value):
        """A peer answered one of our questions; store it, wake up get()."""
        # the subscription filter should only let our own answers through;
        # ignore anything else instead of crashing the worker thread
        if peer_id != self.id:
            return
        name = key.decode()
        self.data[name] = value
        # notify any waiting get() calls that we have an answer
        have_answer = self.pending_questions.pop(name, None)
        if have_answer is not None:
            have_answer.set()

    def _handle_net_up(self, ip):
        """Tracker event: connectivity is back, so (re)start bootstrapping."""
        self.bootstrap = BootstrapState(
            self._request_list,
            self._request_question,
            # give peers time to connect
            self.tracker.beacon.interval * 2,
            log=self.log,
        )

    def _request_list(self):
        """Ask all peers for the list of keys they have."""
        self._pub_send_multipart((b'list ' + self.id,))

    def _handle_list(self, peer_id):
        """A peer asked for our keys; answer with one frame per key."""
        message = list(map(str.encode, self.data))
        message.insert(0, b'list_answer ' + peer_id)
        self._pub_send_multipart(message)

    def _handle_list_answer(self, peer_id, *keys):
        """A peer sent its keys; hand them to the running bootstrap, if any."""
        # ignore answers meant for other nodes (see _handle_answer)
        if peer_id != self.id:
            return
        if not self.bootstrap:
            return
        self.bootstrap.handle_list(map(bytes.decode, keys))

    def _pub_send_multipart(self, message):
        """Send a multipart message (iterable of bytes) to all peers."""
        with self.pub_lock:
            self.pub.send_multipart(message)
            self.log('pub sent:', *message)

    def set(self, key, value):
        """Store value (bytes) under key (str) and publish it to all peers."""
        self.data[key] = value
        # tell everyone about the new value
        self._pub_send_multipart((b'set', key.encode(), value))

    def get(self, key, *, timeout=.1):
        """Return the value for key (str), or None if nobody has it.

        If the key is not known locally, ask the peers and wait up to
        timeout seconds for an answer.
        """
        # if we have it, return it
        value = self.data.get(key)
        if value is not None:
            return value

        # if we don't have it, ask others,
        # and wait `timeout` seconds for an answer

        # setdefault() is likely atomic
        # https://bugs.python.org/issue13521
        # https://mail.python.org/pipermail/python-list/2018-July/885957.html
        have_answer = self.pending_questions.setdefault(key, threading.Event())
        self._request_question(key)

        if have_answer.wait(timeout):
            self.log(f"get({key!r}): got answer")
        else:
            self.log(f"get({key!r}): timed out")

        # if someone answered, the value is already in self.data
        return self.data.get(key)
@dataclass
class BootstrapState:
    """Timer-driven state machine for fetching all keys from the peers.

    step() is meant to be called repeatedly; it does at most one thing
    per call, in this order:

    1. ``list_delay`` seconds after creation, call request_list() once;
    2. while keys announced through handle_list() remain, call
       request_key(key) for one key, at least ``key_delay`` seconds apart;
    3. once no keys remain, wait another ``list_delay`` seconds and
       set ``done``.

    The ``*_after`` fields are deadlines on the ``time`` clock,
    with 0 meaning "not scheduled".
    """

    # called with no arguments to ask peers for their keys
    request_list: callable
    # called with one key to ask peers for its value
    request_key: callable
    list_delay: float
    key_delay: float = 0.005
    list_after: float = 0
    key_after: float = 0
    done_after: float = 0
    done: bool = False
    # keys announced by peers and not requested yet
    remaining_keys: set = field(default_factory=set)
    # keys already requested; never requested a second time
    done_keys: set = field(default_factory=set)
    log: callable = lambda *_: None

    # not annotated, so this is a plain class attribute, not a field
    time = time.monotonic

    def __post_init__(self):
        """Schedule the list request list_delay seconds from now."""
        self.list_after = self.time() + self.list_delay
        self.log(f"bootstrap: waiting {self.list_delay:.1f}s before starting")

    def step(self):
        """Perform whichever scheduled action is due, if any."""
        now = self.time()

        if self.list_after and self.list_after <= now:
            self.request_list()
            self.list_after = 0
            self.key_after = now + self.key_delay
            self.log("bootstrap: started")

        elif self.key_after and self.key_after <= now:
            if self.remaining_keys:
                key = self.remaining_keys.pop()
                self.request_key(key)
                self.done_keys.add(key)

            if self.remaining_keys:
                self.key_after = now + self.key_delay
            else:
                self.key_after = 0
                # should use a separate delay for this, but eh
                self.done_after = now + self.list_delay
                self.log(f"bootstrap: no keys remaining, waiting another {self.list_delay:.1f}s")

        elif self.done_after and self.done_after <= now:
            self.done = True
            self.log("bootstrap: done")

    def handle_list(self, keys):
        """Queue the not-yet-requested keys from a peer's key list."""
        self.remaining_keys.update(k for k in keys if k not in self.done_keys)
        self.key_after = self.time() + self.key_delay
        # there may be work again: cancel a pending "done" deadline so that
        # step() cannot finish before these keys have been requested;
        # it gets re-armed once no keys remain
        self.done_after = 0
if __name__ == '__main__':
    # Demo: start a node that logs to stdout, publish a key unique to it,
    # then (half of the time) publish 'hello', and finally print what
    # 'hello' resolves to, with timestamps before and after the lookup.
    node = DKV(print)
    node.start()

    own_key = f'one-{node.port}'
    node.set(own_key, b'111')

    time.sleep(1 + random.random())

    if random.random() < 0.5:
        greeting = f"from {node.port}".encode()
        node.set('hello', greeting)

    time.sleep(random.random())

    print(time.time())
    print(node.get('hello'))
    print(time.time())

    # keep the process (and its daemon threads) alive for a while
    time.sleep(100)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment