Last active: May 15, 2018 17:10
-
-
Save sorcio/fe8bb54388cda36172e9f67dd9bb8cee to your computer and use it in GitHub Desktop.
sync -> trio magic wrapper (proof of concept)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from contextlib import contextmanager | |
from functools import partial | |
import threading | |
import types | |
from queue import Queue | |
import sys | |
import outcome | |
import trio | |
async def wrap_it_dangerously(callable, *args):
    """Run ``callable(*args)`` in a worker thread, turning its patched
    blocking calls into Trio checkpoints.

    Whenever the sync code calls a make_sync-decorated function (such as the
    stand-in socket module), the requested async step is awaited here on the
    Trio side. Its outcome is then handed back to the worker. When the
    callable finishes, its return value is returned, or its exception is
    re-raised.

    Dangerous: ``worker.send()`` blocks the Trio thread until the worker
    reaches its next step. Any CPU-heavy or genuinely blocking work inside
    the sync code therefore stalls the event loop.
    """
    worker = SyncWrapper(callable, args)
    worker.start()
    reply = None  # the first send() only releases the worker; its value is discarded
    while True:
        step = worker.send(reply)  # blocking
        if worker.complete:
            # ``step`` is the final outcome of callable(*args).
            return step.unwrap()
        # ``step`` is an async callable the worker wants awaited. Capture its
        # result or exception so the worker sees it when it resumes.
        reply = await outcome.acapture(step)
class SyncWrapper(threading.Thread):
    """Worker thread that runs a sync callable in lock-step with the Trio side.

    The two threads take turns:

    - The Trio side calls send(value). This puts *value* on ``in_queue`` and
      blocks on ``out_queue``.
    - The worker then runs until it either needs an async step (a
      make_sync-decorated function puts a callable on ``out_queue``) or
      finishes. When it finishes, it sets ``complete`` and puts the final
      outcome on ``out_queue``.
    """

    def __init__(self, callable, args):
        # daemon=True: if the Trio side is interrupted while blocked in send()
        # (e.g. KeyboardInterrupt), the worker stays parked on in_queue.get()
        # forever. A non-daemon thread would then keep the interpreter from
        # exiting.
        super().__init__(daemon=True)
        self.callable = callable
        self.args = args
        # Trio side -> worker: None first, then the captured outcome of each
        # async step.
        self.in_queue = Queue()
        # Worker -> Trio side: the next async callable to await, or the final
        # outcome.
        self.out_queue = Queue()
        # Set by the worker just before it posts its final outcome.
        self.complete = False

    def run(self):
        # Register this wrapper as the current thread's portal, so that
        # make_sync-decorated functions called by self.callable can find it.
        _portal.portal = self
        # Wait for the initial send() before running anything; its value is
        # unused.
        self.in_queue.get()  # blocking
        result = outcome.capture(self.callable, *self.args)
        # ``complete`` must be set before the put: the Trio side checks it
        # right after its out_queue.get() returns.
        self.complete = True
        self.out_queue.put(result)

    def send(self, value):
        """Resume the worker with *value* and block until its next step.

        Returns either an async callable to await or, once ``complete`` is
        True, the final outcome. The stand-in modules are installed in
        ``sys.modules`` for exactly the window in which the worker runs.
        """
        with monkeypatched_context():
            # Reschedule the worker...
            self.in_queue.put(value)
            # ...and block until it reaches the next virtual checkpoint.
            return self.out_queue.get()  # blocking
# Per-thread slot holding the SyncWrapper that drives the current worker
# thread (assigned in SyncWrapper.run).
_portal = threading.local()


def make_sync(afn):
    """Turn the async function *afn* into a blocking function for worker threads.

    When the returned function is called from inside a SyncWrapper worker
    thread, it does three things:

    1. Posts ``partial(afn, *args, **kwargs)`` to the Trio side.
    2. Blocks until the Trio side has awaited it.
    3. Returns the result, or re-raises the exception captured by the Trio
       side.

    Raises:
        RuntimeError: if called from a thread that is not a SyncWrapper worker.
    """
    from functools import wraps

    @wraps(afn)
    def inner(*args, **kwargs):
        portal = getattr(_portal, 'portal', None)
        if portal is None:
            # Without a portal there is no Trio side to run afn on.
            raise RuntimeError(
                f'{afn.__qualname__}() can only be called from code run by '
                f'wrap_it_dangerously()'
            )
        portal.out_queue.put(partial(afn, *args, **kwargs))
        return portal.in_queue.get().unwrap()  # blocking
    return inner
class PortalizedSocket:
    """Minimal blocking socket facade backed by a trio socket.

    Each method is an async method wrapped with make_sync. Calling it from a
    SyncWrapper worker thread therefore runs the real trio socket operation on
    the Trio side while the worker blocks. Only the methods defined here are
    available.
    """

    @make_sync
    async def __init__(self, *args, **kwargs):
        # trio.socket.socket() needs an async (Trio) context, so even
        # construction goes through the portal via this spurious async
        # __init__.
        self._socket = trio.socket.socket(*args, **kwargs)

    @make_sync
    async def connect(self, *args, **kwargs):
        await self._socket.connect(*args, **kwargs)

    @make_sync
    async def send(self, *args, **kwargs):
        # Return the number of bytes sent, like socket.socket.send. It may be
        # fewer than len(data), so callers must handle partial sends.
        return await self._socket.send(*args, **kwargs)

    @make_sync
    async def recv(self, *args, **kwargs):
        return await self._socket.recv(*args, **kwargs)

    @make_sync
    async def close(self, *args, **kwargs):
        # The trio socket's close() is called without await. It is still
        # routed through the portal, so it runs on the Trio side.
        self._socket.close(*args, **kwargs)

    def __enter__(self):
        # Support ``with socket.socket(...) as s:`` like stdlib sockets.
        return self

    def __exit__(self, *exc_info):
        self.close()
@make_sync
async def portalized_getaddrinfo(host, port, family=0, type=0, proto=0, flags=0):
    """Blocking stand-in for socket.getaddrinfo, resolved by trio on the Trio side."""
    resolved = await trio.socket.getaddrinfo(
        host, port, family=family, type=type, proto=proto, flags=flags
    )
    return resolved
# Replacement namespaces, keyed by the name of the module they stand in for.
# Each stand-in module is built with exactly these attributes.
_monkeypatches = {
    'socket': dict(
        getaddrinfo=portalized_getaddrinfo,
        socket=PortalizedSocket,
        AF_INET=trio.socket.AF_INET,
        SOCK_STREAM=trio.socket.SOCK_STREAM,
    ),
}
def _prepare_monkeypatched_modules(): | |
for module_name, namespace in _monkeypatches.items(): | |
module = types.ModuleType('module_name') | |
for name, value in namespace.items(): | |
setattr(module, name, value) | |
yield module_name, module | |
# Stand-in modules built once at import time. monkeypatched_context() swaps
# them into sys.modules.
_patched_modules = {
    name: module for name, module in _prepare_monkeypatched_modules()
}
@contextmanager
def monkeypatched_context(patched_modules=None):
    """Temporarily install stand-in modules in ``sys.modules``.

    Args:
        patched_modules: mapping of module name -> module object to install.
            Defaults to the module-level ``_patched_modules``.

    On exit, including exit by exception, each name is restored to the
    module it had before entry, or removed if it was absent. ``sys.modules``
    is process-wide, so the stand-ins are visible to every thread while the
    context is active.
    """
    if patched_modules is None:
        patched_modules = _patched_modules
    saved = {}
    try:
        for module_name, module in patched_modules.items():
            saved[module_name] = sys.modules.get(module_name)
            sys.modules[module_name] = module
        yield
    finally:
        # Restore in a finally block. Otherwise an exception raised inside the
        # ``with`` body would leave the stand-ins installed for the whole
        # process.
        for module_name, previous in saved.items():
            if previous is not None:
                sys.modules[module_name] = previous
            else:
                # pop() rather than del: tolerate the entry having been
                # removed meanwhile.
                sys.modules.pop(module_name, None)
def some_sync_code(i):
    """Fetch http://httpbin.org/get using plain blocking socket calls.

    The code is written against the stdlib ``socket`` API. When it runs under
    wrap_it_dangerously(), ``import socket`` picks up the stand-in module
    installed in ``sys.modules``, so each socket call becomes a step executed
    on the Trio side.

    Args:
        i: label used to prefix the progress messages.

    Returns:
        bytearray with every byte received until the peer closed the
        connection (the raw HTTP response).
    """
    import socket

    # GET http://httpbin.org/get
    # Resolve before creating the socket, so a lookup failure leaks nothing.
    addrinfo = socket.getaddrinfo('httpbin.org', 80, socket.AF_INET, socket.SOCK_STREAM)
    sockaddr = addrinfo[0][4]
    sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)
    # Close the socket on every exit path. This includes an exception, or a
    # cancellation propagated from the Trio side, raised by connect/send/recv.
    try:
        print(f'[{i}] connecting to {sockaddr}')
        sock.connect(sockaddr)
        print(f'[{i}] connected')
        # NOTE: send() may transmit only part of the request. sendall() would
        # be safer, but the stand-in socket does not provide it.
        sock.send(b'GET /get HTTP/1.1\r\nHost: httpbin.org\r\nConnection: close\r\n\r\n')
        print(f'[{i}] sent bytes')
        buf = bytearray()
        while True:
            chunk = sock.recv(1024)
            if not chunk:
                break
            buf += chunk
        print(f'[{i}] connection closed')
    finally:
        sock.close()
    return buf
if __name__ == '__main__':
    print('== sync version:')
    print(some_sync_code('sync').decode('ascii'))

    print('== once again, but async-wrapped:')

    async def call_it_like_it_were_async(i, nursery):
        payload = await wrap_it_dangerously(some_sync_code, i)
        # The first task to finish cancels its siblings. This prints only one
        # response and also exercises cancellation of the wrapped sync code.
        nursery.cancel_scope.cancel()
        print(payload.decode('ascii'))

    async def amain():
        async with trio.open_nursery() as nursery:
            for task_id in range(10):
                nursery.start_soon(call_it_like_it_were_async, task_id, nursery)

    trio.run(amain)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.