import argparse
import asyncio
import json
import logging
import secrets
from collections import defaultdict
from functools import partial
from os import environ

import aiopg
import aioredis
import websockets
from psycopg2.extras import Json
from raven import Client

logger = logging.getLogger(__file__)
client = Client(
    "http://8362ae94624245f4bb5c64e146778395:114cae74aa79458abfd5e73ea3cbaacd@sentry.genesiscare.com.au/18"
)
class Store:
    pool = None
    dsn = None

    def __init__(self, *, host, password, user, dbname):
        self._host = host
        self._password = password
        self._user = user
        self._dbname = dbname
        # This maps stream names to their pkeys
        self._stream_map = dict()

    async def async_init(self, setup_database=True):
        dsn = " ".join(
            [
                f"dbname={self._dbname}",
                f"user={self._user}",
                f"password={self._password}",
                f"host={self._host}",
            ]
        )
        self.pool = await aiopg.create_pool(dsn=dsn, enable_hstore=False)
        if setup_database:
            await self.setup_database()
        await self.map_stream_ids()

    async def setup_database(self):
        """
        Initialise the tables etc used by the cache
        """
        # FIXME: dynamically identify the statements to index on based on the filters(!)
        statements = [
            """CREATE TABLE IF NOT EXISTS streams (
                id SERIAL PRIMARY KEY,
                name VARCHAR UNIQUE NOT NULL,
                latest_id VARCHAR DEFAULT '1-1');
            """,
            """CREATE TABLE IF NOT EXISTS messages (
                id SERIAL PRIMARY KEY,
                stream INTEGER REFERENCES streams(id) NOT NULL,
                pkey VARCHAR NOT NULL,
                content JSONB NOT NULL,
                CONSTRAINT one_key_per_stream UNIQUE (stream, pkey)
            );
            """,
            """
            CREATE INDEX CONCURRENTLY IF NOT EXISTS schId ON messages
            (( content#>'{payload,Sch_ID}' ), stream);
            """,
            """
            CREATE INDEX CONCURRENTLY IF NOT EXISTS patId1 ON messages
            (( content#>'{payload,Pat_ID1}' ), stream);
            """,
        ]
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cur:
                for sql in statements:
                    await cur.execute(sql)

    async def map_stream_ids(self):
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cur:
                sql = "SELECT id, name FROM streams"
                await cur.execute(sql)
                async for stream_id, stream_name in cur:
                    self._stream_map[stream_name] = stream_id

    async def get_streams(self):
        """
        Return a dict whose keys are the streams, and
        the values are their latest message IDs
        """
        result = dict()
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cur:
                sql = "SELECT latest_id, name FROM streams"
                await cur.execute(sql)
                async for latest_id, stream_name in cur:
                    result[stream_name] = latest_id
        return result

    async def add_stream(self, stream_name):
        """
        Returns True if we actually needed to add it.
        """
        if stream_name in self._stream_map:
            # already added
            return False
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cur:
                sql = "INSERT INTO streams (name) VALUES (%s) ON CONFLICT DO NOTHING"
                await cur.execute(sql, (stream_name,))
        logger.debug(f"Added new stream named {stream_name}.")
        await self.map_stream_ids()
        return True

    async def append_messages(self, *, stream_name, latest_id, messages):
        if stream_name not in self._stream_map:
            await self.map_stream_ids()
        stream_id = self._stream_map[stream_name]
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cur:
                for message in messages:
                    pkey = f"{message['pkey']}:{message['variant']}"
                    sql = """
                        INSERT INTO messages (stream, pkey, content) VALUES (%s, %s, %s)
                        ON CONFLICT ON CONSTRAINT one_key_per_stream DO UPDATE SET content = EXCLUDED.content
                        ;"""
                    await cur.execute(sql, (stream_id, pkey, Json(message)))
                sql = "UPDATE streams SET latest_id=%s WHERE id=%s"
                logger.debug(f"Appended {len(messages)} messages to {stream_name}")
                await cur.execute(sql, (latest_id, stream_id))
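# Illustrative sketch only (not called anywhere): how Store is expected to be
# driven. The connection parameters and the message shape
# ({"pkey", "variant", "payload"}) are assumptions inferred from
# append_messages() and setup_database() above, not a documented contract.
async def _example_store_usage():
    store = Store(host="localhost", password="secret", user="postgres", dbname="postgres")
    await store.async_init()  # creates the tables/indexes and maps stream ids
    await store.add_stream("appointments")
    await store.append_messages(
        stream_name="appointments",
        latest_id="1532246000000-0",  # the redis stream id consumed up to
        messages=[{"pkey": "123", "variant": "update", "payload": {"Sch_ID": "123"}}],
    )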
class RandomStore(Store):
    """
    Create a randomly-named database.
    Intended for testing only!!
    """

    def __init__(self, **kw):
        self.random_db_name = "db_" + secrets.token_hex()[:20]
        kw["dbname"] = self.random_db_name
        logger.info(f"Using database name {self.random_db_name}")
        super().__init__(**kw)

    async def async_init(self):
        # Create the pool to template1 (without creating the cache tables there)
        self._dbname = "template1"
        await super().async_init(setup_database=False)
        # Yuck injection, but not of user data so calm down
        sql = f"CREATE DATABASE {self.random_db_name} TEMPLATE=template0;"
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cur:
                await cur.execute(sql)
        logger.debug(f"Created temp database {self.random_db_name}.")
        self._dbname = self.random_db_name
        await super().async_init()

    async def destroy_db(self):
        self._dbname = "template1"
        await super().async_init(setup_database=False)
        # Yuck injection, but not of user data so calm down
        sql = f"DROP DATABASE {self.random_db_name};"
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cur:
                await cur.execute(sql)
        logger.debug(f"Destroyed temp database {self.random_db_name}.")
class StreamFilter:
    _store_query = None

    def __init__(self, *, stream_name, store, data):
        self.stream_name = stream_name
        self.store = store
        self.filter_commands = data["filter"] or []
        self._build_store_query()

    def _build_store_query(self):
        """
        Construct the SQL to retrieve the matching messages from
        the database. Silently skip any filters which can't be
        applied at this level; assume they will be applied
        later in python land.
        Supported syntax:
            ("foo", "is", "42")                   # exact match on content->>'foo'
            ("content->>'foo'", "sql", "= '42'")  # raw SQL fragment, useful for jsonb operators
            (https://www.postgresql.org/docs/9.4/static/functions-json.html)
            ("python", "lambda msg: ...", None)   # python-level only; handled by _check_message
        """
        stream_id = self.store._stream_map[self.stream_name]
        filter_expressions = [f"stream={stream_id}"]
        for arg1, operator, arg2 in self.filter_commands:
            if arg1 == "python":
                continue
            # FIXME: quote pg args using psycopg2 or similar
            if operator == "is":
                filter_expressions.append(f"content->>'{arg1}' = '{arg2}'")
            elif operator == "sql":
                filter_expressions.append(f"{arg1} {arg2}")
            else:
                assert False, f"unexpected operator {operator}"
        filter_expression = " AND ".join(filter_expressions)
        self._store_query = f"SELECT content FROM messages WHERE {filter_expression}"

    def from_stream(self, messages):
        for message in messages:
            if self._check_message(message=message):
                yield message

    def _check_message(self, message):
        """
        Perform python-level filtering. Ignore any
        filters which python doesn't understand.
        """
        for arg1, arg2, arg3 in self.filter_commands:
            if arg1 == "python":
                fn = eval(arg2)
                if arg3:
                    fn = partial(fn, **arg3)
                result = fn(message)
                if not result:
                    return False
            elif arg2 == "is":
                if message.get(arg1) != arg3:
                    return False
        return True

    async def from_store(self, limit=None):
        """
        Retrieve all messages matching the filters from the store.
        """
        sql = self._store_query
        if limit is not None:
            sql = f"{sql} ORDER BY id DESC LIMIT {limit}"
        logger.debug(f"Using SQL: {sql}")
        async with self.store.pool.acquire() as conn:
            async with conn.cursor() as cur:
                await cur.execute(sql)
                async for content, in cur:
                    if self._check_message(message=content):
                        yield content
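# Illustrative example only: a filter specification in the shape StreamFilter
# expects (data["filter"] is a list of 3-tuples). The JSONB paths below are
# assumptions based on the indexes created in Store.setup_database, not a
# documented schema.
EXAMPLE_FILTER = {
    "filter": [
        # exact match on a top-level key; applied both in SQL and in python
        ("variant", "is", "update"),
        # raw SQL fragment: the first and third elements are concatenated; SQL-only
        ("content#>>'{payload,Pat_ID1}'", "sql", "= '42'"),
        # python-level predicate: the expression is eval()'d into a callable and,
        # if the third element is truthy, partially applied with it as kwargs
        ("python", "lambda message: 'payload' in message", None),
    ]
}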
class App:
    __streams = None  # used by get_streams to cache their names

    def __init__(self, store):
        self.subscriptions = defaultdict(
            dict
        )  # key is stream name, value is dict of websockets (users) to StreamFilters
        # This is needed for streaming messages which arrive while the client
        # is still digesting the initial set of data
        self.message_buffer = (
            dict()
        )  # key is websocket, value is dict of stream names to pending messages
        self.store = store

    async def monitor(self):
        pass

    async def async_init(self):
        await self.store.async_init()

    async def register(self, websocket):
        logger.info(f"Registering a new client from {websocket.origin}")
        self.message_buffer[websocket] = dict()
        logger.info(f"Total number of clients: {len(self.message_buffer)}")

    async def unregister(self, websocket):
        logger.info(f"Unregistering a client from {websocket.origin}")
        for stream_name in self.message_buffer[websocket]:
            del self.subscriptions[stream_name][websocket]
        del self.message_buffer[websocket]
        logger.info(f"Total number of clients: {len(self.message_buffer)}")

    async def send_messages(self, *, websocket, stream_name, messages, **extra):
        if len(messages) > 0:
            await websocket.send(
                json.dumps(
                    dict(
                        type="messages",
                        stream_name=stream_name,
                        messages=messages,
                        **extra,
                    )
                )
            )

    async def request_resend(self, *, websocket, data):
        """
        Do a one-off send from the cache of the records in the stream
        matching this filter.
        """
        stream_name = data.pop("stream_name")
        stream_filter = StreamFilter(
            data=data, stream_name=stream_name, store=self.store
        )
        messages = [message async for message in stream_filter.from_store()]
        await self.send_messages(
            websocket=websocket, stream_name=stream_name, messages=messages, source=data
        )

    async def subscribe_stream(self, *, websocket, data):
        stream_name = data.pop("stream_name")
        sync_previous = data.pop("sync_previous", True)
        assert isinstance(sync_previous, (int, bool))
        actually_added = await self.store.add_stream(stream_name=stream_name)
        if actually_added:
            await self.get_streams(force_refetch=True)
        stream_filter = StreamFilter(
            data=data, stream_name=stream_name, store=self.store
        )
        assert (
            websocket not in self.subscriptions[stream_name]
        ), "Stream is already subscribed to"
        self.subscriptions[stream_name][websocket] = stream_filter
        logger.info(f"Subscribed to {stream_name} from {websocket.origin}")
        if sync_previous:
            self.message_buffer[websocket][stream_name] = list()
            # retrieve existing messages in the DB via the streamfilter
            # and send them to the client
            limit = None if isinstance(sync_previous, bool) else sync_previous
            messages = [
                message async for message in stream_filter.from_store(limit=limit)
            ]
            await self.send_messages(
                websocket=websocket, stream_name=stream_name, messages=messages
            )
            # Now the database's data has been sent, transmit whatever
            # was buffered in the interim
            mb = self.message_buffer[websocket][stream_name]
            while len(mb) > 0:
                # Empty out the buffer
                self.message_buffer[websocket][stream_name] = list()
                await self.send_messages(
                    websocket=websocket, stream_name=stream_name, messages=mb
                )
                # pick up anything which arrived while we were sending
                mb = self.message_buffer[websocket][stream_name]
            # OK we have actually sent everything pending. We can signal that from now on
            # the messages can be sent immediately for this stream
            self.message_buffer[websocket][stream_name] = None
        else:
            # No backfill requested: new messages can be sent immediately
            self.message_buffer[websocket][stream_name] = None
    async def unsubscribe_stream(self, *, websocket, data):
        stream_name = data.pop("stream_name")
        logger.info(f"Unsubscribed from {stream_name} by {websocket.origin}")
        # silently ignore scenarios where the stream wasn't already subscribed to
        self.subscriptions[stream_name].pop(websocket, None)
        self.message_buffer[websocket].pop(stream_name, None)

    async def get_streams(self, force_refetch=False):
        if self.__streams is None or force_refetch:
            logger.info("Refreshing the cached stream names")
            self.__streams = await self.store.get_streams()
            logger.debug(self.__streams)
        return self.__streams
    async def append_to_stream(self, *, stream_name, latest_id, messages):
        tasks = [
            self.store.append_messages(
                stream_name=stream_name, latest_id=latest_id, messages=messages
            )
        ]
        # Publish these messages to all subscribers
        # (or add them to their buffer if they aren't quite ready yet)
        tasks.extend(
            [
                self.process_messages_for_user(
                    stream_name=stream_name, messages=messages, recipient=user
                )
                for user in self.subscriptions[stream_name]
            ]
        )
        # Update the DB and send the data to the clients in parallel
        await asyncio.wait(tasks)

    async def process_messages_for_user(self, *, stream_name, messages, recipient):
        """
        Either push the messages straight out, or insert them in the buffer.
        Whatever the case, apply their streamfilter first.
        Note: to avoid race conditions, you need to make sure
        to not call await between checking the buffer status
        and operating on it. Otherwise it's possible a person
        could miss some messages, if they are just finishing
        processing their buffer.
        """
        sf = self.subscriptions[stream_name].get(recipient)
        buffer = self.message_buffer.get(recipient)
        if sf is None or buffer is None or stream_name not in buffer:
            # They must have just disconnected/unsubscribed
            return
        filtered_messages = list(sf.from_stream(messages=messages))
        if buffer[stream_name] is None:
            # no buffer, just push straight out
            await self.send_messages(
                websocket=recipient, messages=filtered_messages, stream_name=stream_name
            )
        else:
            buffer[stream_name].extend(filtered_messages)

    async def auth(self, *, websocket, data):
        # TODO
        return True

    async def pong(self, *, websocket, data):
        await websocket.send(json.dumps(dict(type="pong", **data)))
async def entrypoint(websocket, path, app):
    # Track this client so it can receive stream updates
    await app.register(websocket)
    action_map = dict(
        auth=app.auth,
        subscribe_stream=app.subscribe_stream,
        request_resend=app.request_resend,
        pong=app.pong,
        unsubscribe_stream=app.unsubscribe_stream,
    )
    message = "?"
    try:
        async for message in websocket:
            data = json.loads(message)
            action = data.get("action")
            if action is None:
                logger.error(f"No action in {data}")
                continue
            if action in action_map:
                await action_map[action](websocket=websocket, data=data)
            else:
                logger.error(f"Unknown action: {action}")
    except KeyboardInterrupt:
        raise
    except Exception:
        client.captureException()
        logger.exception(f"During processing of {message}")
    finally:
        await app.unregister(websocket)
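# Illustrative examples only: the JSON messages a websocket client is expected
# to send, inferred from action_map and the App handlers above. The stream
# name and filter values are assumptions, not a documented protocol.
EXAMPLE_CLIENT_MESSAGES = [
    json.dumps({"action": "auth"}),
    json.dumps(
        {
            "action": "subscribe_stream",
            "stream_name": "appointments",
            "sync_previous": 100,  # True/False, or an integer LIMIT on the backfill
            "filter": [["variant", "is", "update"]],
        }
    ),
    json.dumps({"action": "request_resend", "stream_name": "appointments", "filter": []}),
    json.dumps({"action": "unsubscribe_stream", "stream_name": "appointments"}),
]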
async def redis_relay(*, app):
    """
    Monitors the redis streams for new data.
    For now, you'll have to restart the app if you add new streams to the DB.
    """
    loop = asyncio.get_event_loop()
    redis = await aioredis.create_redis_pool(
        "redis://test.localhost",  # fixme: config
        minsize=5,
        maxsize=10,
        loop=loop,
        db=0,  # fixme: should be config variable
    )
    while True:
        try:
            streams = await app.get_streams()
            if streams:
                new_messages = defaultdict(list)
                stream_names, latest_ids = zip(*streams.items())
                with await redis as conn:
                    for stream_name, latest_id, content in await conn.xread(
                        list(stream_names),
                        timeout=1000 * 600,  # milliseconds
                        count=10000,
                        latest_ids=list(latest_ids),
                    ):
                        logger.debug(f"Appending to {stream_name}: {content}")
                        try:
                            stream_name = stream_name.decode("utf-8")
                            new_messages[stream_name].append(
                                unpack_message_content(content)
                            )
                        except Exception:
                            client.captureException()
                            logger.exception(f"Cannot process a message:\n{content}")
                        streams[stream_name] = latest_id.decode("utf-8")
                for stream_name, messages in new_messages.items():
                    await app.append_to_stream(
                        stream_name=stream_name,
                        latest_id=streams[stream_name],
                        messages=messages,
                    )
            else:
                logger.warning(
                    "No streams found. Sleeping for a bit, then I'll try again."
                )
                await asyncio.sleep(60)
        except Exception:
            client.captureException()
            logger.exception("oopsy")
            raise
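# Illustrative producer sketch only (not part of the relay): one way a message
# with the shape expected by unpack_message_content() might be published,
# assuming an aioredis 1.x client that exposes xadd. The stream name and field
# values are made up.
async def _example_publish(redis, stream_name="appointments"):
    await redis.xadd(
        stream_name,
        {
            "pkey": "123",
            "variant": "update",
            "payload": json.dumps({"Sch_ID": "123", "Pat_ID1": "42"}),
        },
    )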
def unpack_message_content(content):
    result = dict()
    for k, v in content.items():
        if k == b"payload":
            v = json.loads(v)
        else:
            v = v.decode("utf-8")
        result[k.decode("utf-8")] = v
    return result
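# Illustrative example only: the raw bytes mapping an XREAD entry is expected
# to carry, and the dict unpack_message_content() turns it into. The field
# names are assumptions inferred from append_messages().
def _example_unpack():
    raw = {
        b"pkey": b"123",
        b"variant": b"update",
        b"payload": b'{"Sch_ID": "123", "Pat_ID1": "42"}',
    }
    # returns {"pkey": "123", "variant": "update",
    #          "payload": {"Sch_ID": "123", "Pat_ID1": "42"}}
    return unpack_message_content(raw)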
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, required=True)
    parser.add_argument("--randomstore", action="store_true")
    parser.add_argument("--host", default="postgresql")
    parser.add_argument("--user", default="postgres")
    parser.add_argument("--password", default="secret")
    parser.add_argument("--dbname", default="postgres")
    args = parser.parse_args()
    assert args.port > 1024
    kw = dict(
        host=args.host, user=args.user, password=args.password, dbname=args.dbname
    )
    # The Postgres password may also be supplied via the environment
    if "POSTGRES_PASSWORD" in environ:
        kw["password"] = environ["POSTGRES_PASSWORD"]
    store = RandomStore(**kw) if args.randomstore else Store(**kw)
    app = App(store=store)
    bound_entrypoint = partial(entrypoint, app=app)
    loop = asyncio.get_event_loop()
    loop.run_until_complete(app.async_init())
    logger.info(f"Listening on port {args.port}")
    try:
        tasks = [
            websockets.serve(bound_entrypoint, "0.0.0.0", args.port),
            redis_relay(app=app),
            app.monitor(),
        ]
        loop.run_until_complete(asyncio.wait(tasks))
    finally:
        pass
        # if args.randomstore:
        #     loop.run_until_complete(store.destroy_db())
if __name__ == "__main__":
    logging.basicConfig(
        format="%(asctime)s %(levelname)s {%(module)s} [%(funcName)s] %(message)s",
        datefmt="%Y-%m-%d,%H:%M:%S",
        level=logging.INFO,
    )
    main()