Skip to content

Instantly share code, notes, and snippets.

@Snawoot
Created June 25, 2019 16:03
Show Gist options
  • Save Snawoot/1b0cb9f448f721c233062574dd5f58c3 to your computer and use it in GitHub Desktop.
Save Snawoot/1b0cb9f448f721c233062574dd5f58c3 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
# FOR EDUCATIONAL PURPOSES ONLY
#
import argparse
import asyncio
import base64
import binascii
import enum
import logging
import logging.handlers
import os
import queue
import signal
from functools import partial

import pysodium
def base64_decode(s):
    """Decode a base64 string, tolerating missing ``=`` padding.

    Args:
        s: base64 text as ``str`` or as ASCII ``bytes``/``bytearray``.
            Surrounding whitespace is ignored.

    Returns:
        The decoded ``bytes``.

    Raises:
        ValueError: if the length of the input can never be valid base64
            (``len % 4 == 1``) or the data is otherwise undecodable
            (``binascii.Error`` is a ``ValueError`` subclass).
    """
    # str() on a bytes object would produce "b'...'", which never decodes,
    # so bytes-like input is converted to text explicitly.
    if isinstance(s, (bytes, bytearray)):
        s = bytes(s).decode('ascii')
    s = str(s).strip()
    try:
        return base64.b64decode(s)
    except (TypeError, binascii.Error):
        if len(s) % 4 == 1:
            raise ValueError("Bad base64 string")
        # Pad up to the next multiple of four and retry once.
        return base64.b64decode(s + '=' * (-len(s) % 4))
class Receiver:  # pylint: disable=too-many-instance-attributes
    """TCP server which receives sealed boxes and logs their decrypted content.

    Every client connection is read until EOF; the collected payload is then
    passed to ``pysodium.crypto_box_seal_open`` together with the key pair
    derived from the configured secret key.
    """

    def __init__(self, *,
                 listen_address,
                 listen_port,
                 key,
                 loop=None):
        """Store the listen parameters and derive the key pair.

        Args:
            listen_address: address passed to ``asyncio.start_server``.
            listen_port: port passed to ``asyncio.start_server``.
            key: base64-encoded secret key (padding is optional).
            loop: event loop used to spawn connection handler tasks;
                defaults to ``asyncio.get_event_loop()``.
        """
        self._loop = loop if loop is not None else asyncio.get_event_loop()
        self._logger = logging.getLogger(self.__class__.__name__)
        self._listen_address = listen_address
        self._listen_port = listen_port
        # Tasks of connection handlers which are still running.
        self._children = set()
        self._server = None
        self._sk = base64_decode(key)
        self._pk = pysodium.crypto_scalarmult_curve25519_base(self._sk)
        self._logger.debug("Computed PK = %s", base64.b64encode(self._pk))

    async def stop(self):
        """Close the listening server and cancel all running handlers.

        Safe to call even if ``start()`` has never been awaited.
        """
        if self._server is not None:
            self._server.close()
            await self._server.wait_closed()
        while self._children:
            children = list(self._children)
            self._children.clear()
            self._logger.debug("Cancelling %d client handlers...",
                               len(children))
            for task in children:
                task.cancel()
            await asyncio.wait(children)
            # workaround for TCP server keeps spawning handlers for a while
            # after wait_closed() completed
            await asyncio.sleep(.5)

    async def handler(self, reader, writer):
        """Serve one client: read everything until EOF and try to decrypt it.

        Args:
            reader: ``asyncio.StreamReader`` of the connection.
            writer: ``asyncio.StreamWriter`` of the connection; it is always
                closed when the handler finishes.
        """
        peer_addr = writer.transport.get_extra_info('peername')
        self._logger.info("Client %s connected", str(peer_addr))
        try:
            # Collect chunks and join once instead of repeated bytes
            # concatenation, which is quadratic in the payload size.
            chunks = []
            while True:
                buf = await reader.read(4096)
                if not buf:
                    break
                chunks.append(buf)
            data = b''.join(chunks)
            self._logger.debug("Received: %s", repr(data))
            try:
                dec = pysodium.crypto_box_seal_open(data, self._pk, self._sk)
            except Exception:  # pylint: disable=broad-except
                # Any decryption failure is reported, but KeyboardInterrupt
                # and SystemExit are deliberately not swallowed here.
                self._logger.error("Decryption failed for data: %s", repr(data))
            else:
                self._logger.warning("Extracted secret data: %s", repr(dec))
        except asyncio.CancelledError:  # pylint: disable=try-except-raise
            raise
        except Exception as exc:  # pragma: no cover
            self._logger.exception("Connection handler stopped with exception:"
                                   " %s", str(exc))
        finally:
            self._logger.info("Client %s disconnected", str(peer_addr))
            writer.close()

    async def start(self):
        """Start listening; every accepted connection gets its own task."""
        def _spawn(reader, writer):
            task = self._loop.create_task(self.handler(reader, writer))
            self._children.add(task)
            # The done callback receives the finished task itself, so the
            # bound ``discard`` removes exactly this task from the registry.
            task.add_done_callback(self._children.discard)

        self._server = await asyncio.start_server(_spawn,
                                                  self._listen_address,
                                                  self._listen_port)
        self._logger.info("Server ready.")
class LogLevel(enum.IntEnum):
    """Logging verbosity names accepted on the command line.

    Member values are the numeric levels of the ``logging`` module;
    ``crit`` has the same value as ``fatal`` and is therefore an alias of it.
    """

    debug = logging.DEBUG
    info = logging.INFO
    warn = logging.WARNING
    error = logging.ERROR
    fatal = logging.CRITICAL
    crit = logging.CRITICAL

    def __str__(self):
        """Render the member as its bare name (e.g. ``info``)."""
        return self.name
class OverflowingQueue(queue.Queue):
    """Queue which drops an item instead of raising ``queue.Full``."""

    def put(self, item, block=True, timeout=None):
        """Enqueue *item*; if the queue reports being full, discard it."""
        try:
            outcome = super().put(item, block, timeout)
        except queue.Full:
            outcome = None
        return outcome

    def put_nowait(self, item):
        """Non-blocking ``put``; the item is dropped when no slot is free."""
        return self.put(item, False)
class AsyncLoggingHandler:
    """Context manager wiring a ``QueueHandler`` to a ``QueueListener``.

    Entering starts the listener and returns the queue handler to attach to
    loggers; leaving stops the listener. Records travel through an
    ``OverflowingQueue`` of at most *maxsize* entries to either a
    ``StreamHandler`` (no *logfile*) or a ``FileHandler`` writing to *logfile*.
    """

    def __init__(self, logfile=None, maxsize=1024):
        record_queue = OverflowingQueue(maxsize)
        sink = (logging.StreamHandler() if logfile is None
                else logging.FileHandler(logfile))
        self._listener = logging.handlers.QueueListener(record_queue, sink)
        self._async_handler = logging.handlers.QueueHandler(record_queue)
        line_format = '%(asctime)s %(levelname)-8s %(name)s: %(message)s'
        sink.setFormatter(logging.Formatter(line_format, '%Y-%m-%d %H:%M:%S'))

    def __enter__(self):
        self._listener.start()
        return self._async_handler

    def __exit__(self, exc_type, exc_value, traceback):
        self._listener.stop()
def setup_logger(name, verbosity, handler):
    """Fetch the logger called *name*, set its level and attach *handler*.

    Returns the configured logger.
    """
    target = logging.getLogger(name)
    target.setLevel(verbosity)
    target.addHandler(handler)
    return target
def parse_args(args=None):
    """Parse command line arguments of the receiver.

    Args:
        args: optional list of argument strings. ``None`` (the default)
            makes argparse read them from ``sys.argv``, which is the
            behavior of the parameterless call.

    Returns:
        ``argparse.Namespace`` with ``verbosity``, ``logfile``, ``key``,
        ``bind_address`` and ``bind_port`` attributes.
    """
    def check_port(value):
        """Convert *value* to an int in the range 1..65535."""
        error = argparse.ArgumentTypeError(
            "%s is not a valid port number" % value)
        try:
            ivalue = int(value)
        except ValueError:
            raise error from None
        if not 0 < ivalue < 65536:
            raise error
        return ivalue

    def check_loglevel(arg):
        """Map a level name to the corresponding ``LogLevel`` member."""
        try:
            return LogLevel[arg]
        except KeyError:
            raise argparse.ArgumentTypeError(
                "%s is not valid loglevel" % (repr(arg),)) from None

    parser = argparse.ArgumentParser(
        description="SECret RECeiVer",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("-v", "--verbosity",
                        help="logging verbosity",
                        type=check_loglevel,
                        choices=LogLevel,
                        default=LogLevel.info)
    parser.add_argument("-l", "--logfile",
                        help="log file location",
                        metavar="FILE")
    parser.add_argument("key",
                        help="secret key (base64-encoded without padding)")

    listen_group = parser.add_argument_group('listen options')
    listen_group.add_argument("-a", "--bind-address",
                              default="0.0.0.0",
                              help="bind address")
    listen_group.add_argument("-p", "--bind-port",
                              default=16684,
                              type=check_port,
                              help="bind port")
    return parser.parse_args(args)
def exit_handler(exit_event, signum, frame):  # pragma: no cover pylint: disable=unused-argument
    """Signal handler: graceful stop on the first signal, hard exit on the second.

    Args:
        exit_event: event object; it is set on the first signal, and an
            already set event marks the signal as a repeated one.
        signum: signal number (unused).
        frame: interrupted stack frame (unused).
    """
    logger = logging.getLogger('MAIN')
    if not exit_event.is_set():
        logger.warning("Got first exit signal! Terminating gracefully.")
        exit_event.set()
        return
    logger.warning("Got second exit signal! Terminating hard.")
    # Requires the module-level ``import os``; os._exit terminates the
    # process immediately without running cleanup handlers.
    os._exit(1)  # pylint: disable=protected-access
async def heartbeat():
    """Keep the event loop waking up at a fixed interval, forever.

    Hacky coroutine: per the original author, periodic wake-ups are required
    so that Future and Event state changes get handled even when no other
    events are occurring. It never returns; cancel its task to stop it.
    """
    interval = .5
    while True:
        await asyncio.sleep(interval)
async def amain(args, loop):  # pragma: no cover
    """Run the receiver until an exit signal arrives, then shut it down.

    Args:
        args: parsed command line namespace providing ``bind_address``,
            ``bind_port`` and ``key``.
        loop: event loop handed over to the ``Receiver``.
    """
    logger = logging.getLogger('MAIN')
    receiver = Receiver(listen_address=args.bind_address,
                        listen_port=args.bind_port,
                        key=args.key,
                        loop=loop)
    await receiver.start()
    logger.info("Server started.")

    exit_event = asyncio.Event()
    beat = asyncio.ensure_future(heartbeat())
    on_signal = partial(exit_handler, exit_event)
    for signum in (signal.SIGTERM, signal.SIGINT):
        signal.signal(signum, on_signal)

    # Sleep here until the signal handler sets the event.
    await exit_event.wait()
    beat.cancel()
    await receiver.stop()
def main():  # pragma: no cover
    """Script entry point: parse arguments, set up logging, run the server."""
    args = parse_args()
    with AsyncLoggingHandler(args.logfile) as log_handler:
        logger = setup_logger('MAIN', args.verbosity, log_handler)
        setup_logger(Receiver.__name__, args.verbosity, log_handler)
        # Create the loop explicitly: asyncio.get_event_loop() without a
        # current loop is deprecated and is an error on newer Pythons.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(amain(args, loop))
        finally:
            # Release the loop's resources even when amain() raises.
            loop.close()
        logger.info("Server finished its work.")
# Start the server only when the file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment