-
-
Save Snawoot/1b0cb9f448f721c233062574dd5f58c3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
pysodium |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# FOR EDUCATIONAL PURPOSES ONLY | |
# | |
import asyncio | |
import argparse | |
import logging | |
import logging.handlers | |
import queue | |
import enum | |
import signal | |
import base64 | |
import binascii | |
from functools import partial | |
import pysodium | |
def base64_decode(s):
    """Decode a base64 string, restoring missing ``=`` padding if needed.

    Args:
        s: Base64 text. ``bytes``/``bytearray`` input is interpreted as
            ASCII; any other object is converted with ``str()``.
            Surrounding whitespace is ignored.

    Returns:
        The decoded ``bytes``.

    Raises:
        ValueError: If decoding fails and the input length modulo 4 is 1,
            which no amount of padding can repair. Non-ASCII bytes input
            also ends up here (``UnicodeDecodeError`` is a ``ValueError``).
        binascii.Error: If ``base64.b64decode`` still rejects the input
            after the padding has been restored.
    """
    if isinstance(s, (bytes, bytearray)):
        # str(b"...") would produce the literal "b'...'" representation and
        # silently decode to garbage, so convert raw bytes to text first.
        s = bytes(s).decode('ascii')
    s = str(s).strip()
    try:
        return base64.b64decode(s)
    except (TypeError, binascii.Error):
        if len(s) % 4 == 1:
            raise ValueError("Bad base64 string")
        # Append as many '=' as needed to reach a multiple of four characters
        # (two for remainder 2, one for remainder 3, none for remainder 0).
        return base64.b64decode(s + '=' * (-len(s) % 4))
class Receiver:  # pylint: disable=too-many-instance-attributes
    """TCP server that reads one message per connection and tries to open it
    with ``pysodium.crypto_box_seal_open``.

    Each client connection is read until EOF. The collected bytes are then
    decrypted with the configured secret key and the public key derived from
    it; the outcome is only reported through the logger.
    """

    def __init__(self, *,
                 listen_address,
                 listen_port,
                 key,
                 loop=None):
        """Store the listen parameters and derive the key pair.

        Args:
            listen_address: Address passed to ``asyncio.start_server``.
            listen_port: Port passed to ``asyncio.start_server``.
            key: Base64-encoded secret key (padding is optional).
            loop: Event loop used to create handler tasks; defaults to
                ``asyncio.get_event_loop()``.
        """
        self._loop = loop if loop is not None else asyncio.get_event_loop()
        self._logger = logging.getLogger(self.__class__.__name__)
        self._listen_address = listen_address
        self._listen_port = listen_port
        # Tasks of the connection handlers that are currently running.
        self._children = set()
        self._server = None
        self._sk = base64_decode(key)
        self._pk = pysodium.crypto_scalarmult_curve25519_base(self._sk)
        self._logger.debug("Computed PK = %s", base64.b64encode(self._pk))

    async def stop(self):
        """Close the listening server and cancel all running handlers.

        Does nothing if the server was never started.
        """
        if self._server is None:
            # start() has not run, so there is no server or handler to stop.
            return
        self._server.close()
        await self._server.wait_closed()
        while self._children:
            children = list(self._children)
            self._children.clear()
            self._logger.debug("Cancelling %d client handlers...",
                               len(children))
            for task in children:
                task.cancel()
            await asyncio.wait(children)
            # workaround for TCP server keeps spawning handlers for a while
            # after wait_closed() completed
            await asyncio.sleep(.5)

    async def handler(self, reader, writer):
        """Serve one client: read until EOF, then try to decrypt the data.

        A successfully opened message is logged at WARNING level and a
        decryption failure at ERROR level. The connection is always closed
        at the end.

        Args:
            reader: Stream reader of the client connection.
            writer: Stream writer of the client connection.
        """
        peer_addr = writer.transport.get_extra_info('peername')
        self._logger.info("Client %s connected", str(peer_addr))
        try:
            # Collect the chunks and join once instead of growing a bytes
            # object with repeated concatenation.
            chunks = []
            while True:
                buf = await reader.read(4096)
                if not buf:
                    break
                chunks.append(buf)
            data = b''.join(chunks)
            self._logger.debug("Received: %s", repr(data))
            try:
                dec = pysodium.crypto_box_seal_open(data, self._pk, self._sk)
            except Exception:  # pylint: disable=broad-except
                # Any decryption error is reported, but SystemExit and
                # KeyboardInterrupt are no longer swallowed here.
                self._logger.error("Decryption failed for data: %s", repr(data))
            else:
                self._logger.warning("Extracted secret data: %s", repr(dec))
        except asyncio.CancelledError:  # pylint: disable=try-except-raise
            raise
        except Exception as exc:  # pragma: no cover
            self._logger.exception("Connection handler stopped with exception:"
                                   " %s", str(exc))
        finally:
            self._logger.info("Client %s disconnected", str(peer_addr))
            writer.close()

    async def start(self):
        """Start listening and spawn a tracked handler task per connection."""
        def _spawn(reader, writer):
            task = self._loop.create_task(self.handler(reader, writer))
            self._children.add(task)
            # The done callback receives the finished task itself, so the
            # set's discard method can be used directly to untrack it.
            task.add_done_callback(self._children.discard)

        self._server = await asyncio.start_server(_spawn,
                                                  self._listen_address,
                                                  self._listen_port)
        self._logger.info("Server ready.")
class LogLevel(enum.IntEnum):
    """Short names for the numeric severities of the ``logging`` module.

    ``fatal`` and ``crit`` share one numeric value, which makes ``crit`` an
    alias of ``fatal``.
    """

    debug = logging.DEBUG
    info = logging.INFO
    warn = logging.WARNING
    error = logging.ERROR
    fatal = logging.CRITICAL
    crit = logging.FATAL

    def __str__(self):
        """Return the bare member name instead of the default rendering."""
        return self.name
class OverflowingQueue(queue.Queue):
    """Queue that drops an item silently instead of raising ``queue.Full``."""

    def put(self, item, block=True, timeout=None):
        """Enqueue ``item``; discard it if the queue reports it is full.

        Takes the same arguments as ``queue.Queue.put`` and returns ``None``
        whether or not the item was stored.
        """
        try:
            outcome = super().put(item, block, timeout)
        except queue.Full:
            outcome = None
        return outcome

    def put_nowait(self, item):
        """Non-blocking ``put``; the item is dropped when there is no room."""
        return self.put(item, block=False)
class AsyncLoggingHandler:
    """Context manager that routes log records through a bounded queue.

    Entering starts a ``QueueListener`` that forwards records to a stream or
    file handler and yields the ``QueueHandler`` to attach to loggers.
    Leaving stops the listener.
    """

    def __init__(self, logfile=None, maxsize=1024):
        """Create the queue, the output handler and both queue endpoints.

        Args:
            logfile: Path of the log file; when ``None`` a
                ``logging.StreamHandler`` is used instead.
            maxsize: Capacity of the record queue.
        """
        record_queue = OverflowingQueue(maxsize)
        if logfile is not None:
            sink = logging.FileHandler(logfile)
        else:
            sink = logging.StreamHandler()
        sink.setFormatter(logging.Formatter('%(asctime)s '
                                            '%(levelname)-8s '
                                            '%(name)s: %(message)s',
                                            '%Y-%m-%d %H:%M:%S'))
        self._listener = logging.handlers.QueueListener(record_queue, sink)
        self._async_handler = logging.handlers.QueueHandler(record_queue)

    def __enter__(self):
        """Start the listener and return the queue-backed handler."""
        self._listener.start()
        return self._async_handler

    def __exit__(self, exc_type, exc_value, traceback):
        """Stop the listener; exceptions are not suppressed."""
        self._listener.stop()
def setup_logger(name, verbosity, handler):
    """Attach ``handler`` to the named logger and set its level.

    Args:
        name: Name passed to ``logging.getLogger``.
        verbosity: Level passed to ``Logger.setLevel``.
        handler: Handler added to the logger.

    Returns:
        The configured logger.
    """
    configured = logging.getLogger(name)
    configured.addHandler(handler)
    configured.setLevel(verbosity)
    return configured
def parse_args():
    """Parse the process command line.

    Returns:
        The ``argparse.Namespace`` with ``verbosity``, ``logfile``, ``key``,
        ``bind_address`` and ``bind_port``.
    """

    def check_port(value):
        """Convert ``value`` to an int in the range 1..65535 or reject it."""
        try:
            port = int(value)
        except ValueError:
            port = None
        if port is None or not 0 < port < 65536:
            raise argparse.ArgumentTypeError(
                "%s is not a valid port number" % value)
        return port

    def check_loglevel(arg):
        """Look up the ``LogLevel`` member named ``arg`` or reject it."""
        try:
            level = LogLevel[arg]
        except (IndexError, KeyError):
            raise argparse.ArgumentTypeError("%s is not valid loglevel" % (repr(arg),))
        return level

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="SECret RECeiVer")

    # General options and the positional key.
    parser.add_argument("-v", "--verbosity",
                        type=check_loglevel,
                        choices=LogLevel,
                        default=LogLevel.info,
                        help="logging verbosity")
    parser.add_argument("-l", "--logfile",
                        metavar="FILE",
                        help="log file location")
    parser.add_argument("key",
                        help="secret key (base64-encoded without padding)")

    # Where the server listens.
    listen_group = parser.add_argument_group('listen options')
    listen_group.add_argument("-a", "--bind-address",
                              help="bind address",
                              default="0.0.0.0")
    listen_group.add_argument("-p", "--bind-port",
                              help="bind port",
                              type=check_port,
                              default=16684)

    return parser.parse_args()
def exit_handler(exit_event, signum, frame):  # pragma: no cover pylint: disable=unused-argument
    """Signal handler: request a graceful stop, or exit hard on a repeat.

    The first signal sets ``exit_event``. A signal received while the event
    is already set terminates the process immediately with exit status 1.

    Args:
        exit_event: Object providing ``is_set()`` and ``set()``.
        signum: Signal number (unused).
        frame: Interrupted stack frame (unused).
    """
    # ``os`` is not imported at module level in this file; without this
    # import the hard-exit branch below would raise NameError.
    import os  # pylint: disable=import-outside-toplevel

    logger = logging.getLogger('MAIN')
    if exit_event.is_set():
        logger.warning("Got second exit signal! Terminating hard.")
        os._exit(1)  # pylint: disable=protected-access
    else:
        logger.warning("Got first exit signal! Terminating gracefully.")
        exit_event.set()
async def heartbeat():
    """Keep the event loop waking up at a fixed interval, forever.

    Hacky coroutine: it makes the loop spin even when no events are coming,
    which is required to notice Future and Event state changes when no
    events are occurring. It never returns on its own and has to be
    cancelled.
    """
    interval = .5
    while True:
        await asyncio.sleep(interval)
async def amain(args, loop):  # pragma: no cover
    """Run the receiver until SIGTERM or SIGINT requests a shutdown.

    Args:
        args: Parsed command-line namespace providing ``bind_address``,
            ``bind_port`` and ``key``.
        loop: Event loop handed to the ``Receiver``.
    """
    log = logging.getLogger('MAIN')
    receiver = Receiver(key=args.key,
                        listen_address=args.bind_address,
                        listen_port=args.bind_port,
                        loop=loop)
    await receiver.start()
    log.info("Server started.")

    stop_requested = asyncio.Event()
    # Keep the loop ticking while waiting for the event to be set.
    ticker = asyncio.ensure_future(heartbeat())
    on_signal = partial(exit_handler, stop_requested)
    for signum in (signal.SIGTERM, signal.SIGINT):
        signal.signal(signum, on_signal)

    await stop_requested.wait()
    ticker.cancel()
    await receiver.stop()
def main():  # pragma: no cover
    """Script entry point: parse arguments, set up logging, run the server.

    Everything runs inside the ``AsyncLoggingHandler`` context, so the final
    log message is emitted while the log listener is still active.
    """
    args = parse_args()
    with AsyncLoggingHandler(args.logfile) as log_handler:
        logger = setup_logger('MAIN', args.verbosity, log_handler)
        setup_logger(Receiver.__name__, args.verbosity, log_handler)
        # Create the loop explicitly: asyncio.get_event_loop() without a
        # running loop is deprecated and fails on recent Python versions.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(amain(args, loop))
        finally:
            # Release the loop's resources even if amain() raised.
            loop.close()
        logger.info("Server finished its work.")
# Start the receiver only when this file is executed as a script.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment