from collections import defaultdict
from functools import partial
from os import environ
import argparse
import asyncio
import json
import logging
import secrets

import aiopg
import aioredis
import websockets
from psycopg2.extras import Json
from raven import Client

logger = logging.getLogger(__file__)
client = Client(
    "http://8362ae94624245f4bb5c64e146778395:114cae74aa79458abfd5e73ea3cbaacd@sentry.genesiscare.com.au/18"
)


class Store:
    pool = None
    dsn = None

    def __init__(self, *, host, password, user, dbname):
        self._host = host
        self._password = password
        self._user = user
        self._dbname = dbname
        # This maps stream names to their database ids
        self._stream_map = dict()

    async def async_init(self, setup_database=True):
        dsn = " ".join(
            [
                f"dbname={self._dbname}",
                f"user={self._user}",
                f"password={self._password}",
                f"host={self._host}",
            ]
        )
        self.pool = await aiopg.create_pool(dsn=dsn, enable_hstore=False)
        if setup_database:
            await self.setup_database()
            await self.map_stream_ids()

    async def setup_database(self):
        """
        Initialise the tables etc. used by the cache
        """
        # FIXME: dynamically identify the statements to index on based on the filters(!)
        statements = [
            """CREATE TABLE IF NOT EXISTS streams (
                id SERIAL PRIMARY KEY,
                name VARCHAR UNIQUE NOT NULL,
                latest_id VARCHAR DEFAULT '1-1');
            """,
            """CREATE TABLE IF NOT EXISTS messages (
                id SERIAL PRIMARY KEY,
                stream INTEGER REFERENCES streams(id) NOT NULL,
                pkey VARCHAR NOT NULL,
                content JSONB NOT NULL,
                CONSTRAINT one_key_per_stream UNIQUE (stream, pkey)
            );
            """,
            """
            CREATE INDEX CONCURRENTLY IF NOT EXISTS schId ON messages
            (( content#>'{payload,Sch_ID}' ), stream);
            """,
            """
            CREATE INDEX CONCURRENTLY IF NOT EXISTS patId1 ON messages
            (( content#>'{payload,Pat_ID1}' ), stream);
            """,
        ]
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cur:
                for sql in statements:
                    await cur.execute(sql)

    async def map_stream_ids(self):
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cur:
                sql = "SELECT id, name FROM streams"
                await cur.execute(sql)
                async for stream_id, stream_name in cur:
                    self._stream_map[stream_name] = stream_id

    async def get_streams(self):
        """
        Return a dict whose keys are the stream names, and
        whose values are their latest message IDs.
        """
        result = dict()
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cur:
                sql = "SELECT latest_id, name FROM streams"
                await cur.execute(sql)
                async for latest_id, stream_name in cur:
                    result[stream_name] = latest_id
        return result

    async def add_stream(self, stream_name):
        """
        Return True if we actually needed to add it.
        """
        if stream_name in self._stream_map:
            # already added
            return False
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cur:
                sql = "INSERT INTO streams (name) VALUES (%s) ON CONFLICT DO NOTHING"
                await cur.execute(sql, (stream_name,))
        logger.debug(f"Added new stream named {stream_name}.")
        await self.map_stream_ids()
        return True

    async def append_messages(self, *, stream_name, latest_id, messages):
        if stream_name not in self._stream_map:
            await self.map_stream_ids()
        stream_id = self._stream_map[stream_name]
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cur:
                for message in messages:
                    pkey = f"{message['pkey']}:{message['variant']}"
                    sql = """
                        INSERT INTO messages (stream, pkey, content) VALUES (%s, %s, %s)
                        ON CONFLICT ON CONSTRAINT one_key_per_stream DO UPDATE SET content = EXCLUDED.content
                        ;"""
                    await cur.execute(sql, (stream_id, pkey, Json(message)))
                sql = "UPDATE streams SET latest_id=%s WHERE id=%s"
                await cur.execute(sql, (latest_id, stream_id))
        logger.debug(f"Appended {len(messages)} messages to {stream_name}")
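
# A sketch of the message shape the store expects (an assumption inferred from
# the surrounding code, not spelled out in the original gist): append_messages
# reads message["pkey"] and message["variant"], and unpack_message_content
# (below) builds dicts like this from raw redis stream entries:
#
#     message = {
#         "pkey": "12345",     # hypothetical primary-key value
#         "variant": "full",   # hypothetical variant label
#         "payload": {"Sch_ID": 7, "Pat_ID1": 9},  # JSON-decoded payload
#     }
#
# The composite key stored in postgres would then be "12345:full".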


class RandomStore(Store):
    """
    Create a randomly-named database.
    Intended for testing only!!
    """

    def __init__(self, **kw):
        self.random_db_name = "db_" + secrets.token_hex()[:20]
        kw["dbname"] = self.random_db_name
        logger.info(f"Using database name {self.random_db_name}")
        super().__init__(**kw)

    async def async_init(self):
        # Create the pool against template1 first; skip table creation there
        # so we don't pollute the template database
        self._dbname = "template1"
        await super().async_init(setup_database=False)
        # Yuck injection, but not of user data so calm down
        sql = f"CREATE DATABASE {self.random_db_name} TEMPLATE=template0;"
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cur:
                await cur.execute(sql)
        logger.debug(f"Created temp database {self.random_db_name}.")
        self._dbname = self.random_db_name
        await super().async_init()

    async def destroy_db(self):
        self._dbname = "template1"
        await super().async_init(setup_database=False)
        # Yuck injection, but not of user data so calm down
        sql = f"DROP DATABASE {self.random_db_name};"
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cur:
                await cur.execute(sql)
        logger.debug(f"Destroyed temp database {self.random_db_name}.")


class StreamFilter:
    _store_query = None

    def __init__(self, *, stream_name, store, data):
        self.stream_name = stream_name
        self.store = store
        # tolerate a missing or null "filter" key in the client payload
        self.filter_commands = data.get("filter") or []
        self._build_store_query()

    def _build_store_query(self):
        """
        Construct the SQL to retrieve the matching messages from
        the database. Silently skip any filters which can't be
        applied at this level; assume they will be applied
        later, in python land.
        Supported syntax:
            ("foo", "is", "42")  # exact match on content->>'foo'
            ("content->>'foo'", "sql", "= '42'")  # raw SQL fragment
            (https://www.postgresql.org/docs/9.4/static/functions-json.html)
            ("python", <expression>, <kwargs>)  # skipped here; see _check_message
        """
        stream_id = self.store._stream_map[self.stream_name]
        filter_expressions = [f"stream={stream_id}"]
        for arg1, operator, arg2 in self.filter_commands:
            if arg1 == "python":
                continue
            # FIXME: quote pg args using psycopg2 or similar
            if operator == "is":
                filter_expressions.append(f"content->>'{arg1}' = '{arg2}'")
            elif operator == "sql":
                filter_expressions.append(f"{arg1} {arg2}")
            else:
                assert False, f"unexpected operator {operator}"
        filter_expression = " AND ".join(filter_expressions)
        self._store_query = f"SELECT content FROM messages WHERE {filter_expression}"

    def from_stream(self, messages):
        for message in messages:
            if self._check_message(message=message):
                yield message

    def _check_message(self, message):
        """
        Perform python-level filtering. Ignore any
        filters which python doesn't understand.
        """
        for arg1, operator, arg2 in self.filter_commands:
            if arg1 == "python":
                # here `operator` is a python expression evaluating to a
                # callable, and `arg2` (if truthy) is a dict of keyword
                # arguments to bind to it
                fn = eval(operator)
                if arg2:
                    fn = partial(fn, **arg2)
                if not fn(message):
                    return False
            elif operator == "is":
                if message.get(arg1) != arg2:
                    return False
        return True

    async def from_store(self, limit=None):
        """
        Retrieve all messages matching the filters from the store.
        """
        sql = self._store_query
        if limit is not None:
            sql = f"{sql} ORDER BY id DESC LIMIT {limit}"
        logger.debug(f"Using SQL: {sql}")
        async with self.store.pool.acquire() as conn:
            async with conn.cursor() as cur:
                await cur.execute(sql)
                async for content, in cur:  # each row is a 1-tuple
                    if self._check_message(message=content):
                        yield content


class App:
    __streams = None  # used by get_streams to cache the stream name -> latest-id mapping

    def __init__(self, store):
        # key is stream name; value is a dict of websockets (users) to StreamFilters
        self.subscriptions = defaultdict(dict)
        # This is needed for buffering messages which arrive while the client
        # is still digesting the initial set of data.
        # key is websocket; value is a dict of stream names to pending messages
        self.message_buffer = dict()
        self.store = store

    async def monitor(self):
        pass

    async def async_init(self):
        await self.store.async_init()

    async def register(self, websocket):
        logger.info(f"Registering a new client from {websocket.origin}")
        self.message_buffer[websocket] = dict()
        logger.info(f"Total number of clients: {len(self.message_buffer)}")

    async def unregister(self, websocket):
        logger.info(f"Unregistering a client from {websocket.origin}")
        for stream_name in self.message_buffer[websocket]:
            self.subscriptions[stream_name].pop(websocket, None)
        del self.message_buffer[websocket]
        logger.info(f"Total number of clients: {len(self.message_buffer)}")

    async def send_messages(self, *, websocket, stream_name, messages, **extra):
        if len(messages) > 0:
            await websocket.send(
                json.dumps(
                    dict(
                        type="messages",
                        stream_name=stream_name,
                        messages=messages,
                        **extra,
                    )
                )
            )

    async def request_resend(self, *, websocket, data):
        """
        Do a one-off send from the cache of the records in the stream
        matching this filter.
        """
        stream_name = data.pop("stream_name")
        stream_filter = StreamFilter(
            data=data, stream_name=stream_name, store=self.store
        )
        messages = [message async for message in stream_filter.from_store()]
        await self.send_messages(
            websocket=websocket, stream_name=stream_name, messages=messages, source=data
        )

    async def subscribe_stream(self, *, websocket, data):
        stream_name = data.pop("stream_name")
        sync_previous = data.pop("sync_previous", True)
        assert isinstance(sync_previous, (int, bool))
        actually_added = await self.store.add_stream(stream_name=stream_name)
        if actually_added:
            await self.get_streams(force_refetch=True)
        stream_filter = StreamFilter(
            data=data, stream_name=stream_name, store=self.store
        )
        assert (
            websocket not in self.subscriptions[stream_name]
        ), "Stream is already subscribed to"
        self.subscriptions[stream_name][websocket] = stream_filter
        logger.info(f"Subscribed to {stream_name} from {websocket.origin}")
        if sync_previous:
            self.message_buffer[websocket][stream_name] = list()
            # retrieve existing messages in the DB via the streamfilter
            # and send them to the client
            limit = None if isinstance(sync_previous, bool) else sync_previous
            messages = [
                message async for message in stream_filter.from_store(limit=limit)
            ]
            await self.send_messages(
                websocket=websocket, stream_name=stream_name, messages=messages
            )
            # Now the database's data has been sent, transmit whatever
            # was buffered in the interim
            mb = self.message_buffer[websocket][stream_name]
            while len(mb) > 0:
                # Empty out the buffer; more messages may arrive while we
                # are awaiting the send, so re-check it afterwards
                self.message_buffer[websocket][stream_name] = list()
                await self.send_messages(
                    websocket=websocket, stream_name=stream_name, messages=mb
                )
                mb = self.message_buffer[websocket][stream_name]
        # OK, we have actually sent everything pending. We can signal that from
        # now on the messages can be sent immediately for this stream
        self.message_buffer[websocket][stream_name] = None
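
    # For illustration, an assumed wire format inferred from the handlers in
    # this class ("appointments" and the filter fields are hypothetical):
    #     {"action": "subscribe_stream",
    #      "stream_name": "appointments",
    #      "sync_previous": 100,           # or true / false
    #      "filter": [["status", "is", "active"]]}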

    async def unsubscribe_stream(self, *, websocket, data):
        stream_name = data.pop("stream_name")
        logger.info(f"Unsubscribing {websocket.origin} from {stream_name}")
        # silently ignore scenarios where the stream wasn't already subscribed to
        self.subscriptions[stream_name].pop(websocket, None)
        self.message_buffer[websocket].pop(stream_name, None)

    async def get_streams(self, force_refetch=False):
        if self.__streams is None or force_refetch:
            logger.info("Forcing a refresh of the stream names")
            self.__streams = await self.store.get_streams()
            logger.debug(self.__streams)
        return self.__streams

    async def append_to_stream(self, *, stream_name, latest_id, messages):
        tasks = [
            self.store.append_messages(
                stream_name=stream_name, latest_id=latest_id, messages=messages
            )
        ]
        # Publish these messages to all subscribers
        # (or add them to their buffer if they aren't quite ready yet)
        tasks.extend(
            [
                self.process_messages_for_user(
                    stream_name=stream_name, messages=messages, recipient=user
                )
                for user in self.subscriptions[stream_name]
            ]
        )
        # Update the DB and send the data to the clients in parallel
        await asyncio.wait(tasks)

    async def process_messages_for_user(self, *, stream_name, messages, recipient):
        """
        Either push the messages straight out, or insert them in the buffer.
        Whatever the case, apply their streamfilter first.
        Note: to avoid race conditions, don't await between checking the
        buffer status and operating on it. Otherwise a client which is just
        finishing processing its buffer could miss some messages.
        """
        sf = self.subscriptions[stream_name].get(recipient)
        buffer = self.message_buffer.get(recipient)
        if sf is None or buffer is None or stream_name not in buffer:
            # They must have just disconnected/unsubscribed
            return
        filtered_messages = list(sf.from_stream(messages=messages))
        if buffer[stream_name] is None:
            # no buffer, just push straight out
            await self.send_messages(
                websocket=recipient, messages=filtered_messages, stream_name=stream_name
            )
        else:
            buffer[stream_name].extend(filtered_messages)

    async def auth(self, *, websocket, data):
        # TODO
        return True

    async def pong(self, *, websocket, data):
        await websocket.send(json.dumps(dict(type="pong", **data)))


async def entrypoint(websocket, path, app):
    await app.register(websocket)
    action_map = dict(
        auth=app.auth,
        subscribe_stream=app.subscribe_stream,
        request_resend=app.request_resend,
        pong=app.pong,
        unsubscribe_stream=app.unsubscribe_stream,
    )
    message = "?"
    try:
        async for message in websocket:
            data = json.loads(message)
            action = data.get("action")
            if action is None:
                logger.error(f"No action in {data}")
                continue
            if action in action_map:
                await action_map[action](websocket=websocket, data=data)
            else:
                logger.error(f"Unknown action: {action}")
    except KeyboardInterrupt:
        raise
    except Exception:
        client.captureException()
        logger.exception(f"During processing of {message}")
    finally:
        await app.unregister(websocket)


async def redis_relay(*, app):
    """
    Monitor the redis streams for new data.
    For now, you'll have to restart the app if you add new streams to the DB.
    """
    loop = asyncio.get_event_loop()
    redis = await aioredis.create_redis_pool(
        "redis://test.localhost",  # fixme: config
        minsize=5,
        maxsize=10,
        loop=loop,
        db=0,  # fixme: should be a config variable
    )
    while True:
        try:
            streams = await app.get_streams()
            if streams:
                new_messages = defaultdict(list)
                stream_names, latest_ids = zip(*streams.items())
                with await redis as conn:
                    for stream_name, latest_id, content in await conn.xread(
                        list(stream_names),
                        timeout=1000 * 600,  # milliseconds
                        count=10000,
                        latest_ids=list(latest_ids),
                    ):
                        logger.debug(f"Appending to {stream_name}: {content}")
                        try:
                            stream_name = stream_name.decode("utf-8")
                            new_messages[stream_name].append(
                                unpack_message_content(content)
                            )
                        except Exception:
                            client.captureException()
                            logger.exception(f"Cannot process a message:\n{content}")
                        streams[stream_name] = latest_id.decode("utf-8")
                for stream_name, messages in new_messages.items():
                    await app.append_to_stream(
                        stream_name=stream_name,
                        latest_id=streams[stream_name],
                        messages=messages,
                    )
            else:
                logger.warning(
                    "No streams found. Sleeping for a bit, then I'll try again."
                )
                await asyncio.sleep(60)
        except Exception:
            client.captureException()
            logger.exception("oopsy")
            raise


def unpack_message_content(content):
    result = dict()
    for k, v in content.items():
        if k == b"payload":
            v = json.loads(v)
        else:
            v = v.decode("utf-8")
        result[k.decode("utf-8")] = v
    return result
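

# For illustration (hypothetical values): a raw redis entry of
#     {b"pkey": b"12345", b"variant": b"full", b"payload": b'{"Sch_ID": 7}'}
# unpacks to
#     {"pkey": "12345", "variant": "full", "payload": {"Sch_ID": 7}}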


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, required=True)
    parser.add_argument("--randomstore", action="store_true")
    parser.add_argument("--host", default="postgresql")
    parser.add_argument("--user", default="postgres")
    parser.add_argument("--password", default="secret")
    parser.add_argument("--dbname", default="postgres")
    args = parser.parse_args()
    assert args.port > 1024
    kw = dict(
        host=args.host, user=args.user, password=args.password, dbname=args.dbname
    )
    if "POSTGRES_PASSWORD" in environ:
        kw["password"] = environ["POSTGRES_PASSWORD"]
    store = RandomStore(**kw) if args.randomstore else Store(**kw)
    app = App(store=store)
    bound_entrypoint = partial(entrypoint, app=app)
    loop = asyncio.get_event_loop()
    loop.run_until_complete(app.async_init())
    logger.info(f"Listening on port {args.port}")
    try:
        tasks = [
            websockets.serve(bound_entrypoint, "0.0.0.0", args.port),
            redis_relay(app=app),
            app.monitor(),
        ]
        loop.run_until_complete(asyncio.wait(tasks))
    finally:
        pass
        # if args.randomstore:
        #     loop.run_until_complete(store.destroy_db())
if __name__ == "__main__": | |
logging.basicConfig( | |
format="%(asctime)s %(levelname)s {%(module)s} [%(funcName)s] %(message)s", | |
datefmt="%Y-%m-%d,%H:%M:%S", | |
level=logging.INFO, | |
) | |
main() |