Skip to content

Instantly share code, notes, and snippets.

@sorcio
Last active May 15, 2018 17:10
Show Gist options
  • Save sorcio/fe8bb54388cda36172e9f67dd9bb8cee to your computer and use it in GitHub Desktop.
Save sorcio/fe8bb54388cda36172e9f67dd9bb8cee to your computer and use it in GitHub Desktop.
sync -> trio magic wrapper (proof of concept)
from contextlib import contextmanager
from functools import partial
import threading
import types
from queue import Queue
import sys
import outcome
import trio
async def wrap_it_dangerously(callable, *args):
    """
    Execute some sync code making it more Trio-friendly by magically
    transforming some blocking calls into checkpoints.

    ``callable(*args)`` runs in a dedicated ``SyncWrapper`` thread. Every time
    that code reaches a patched blocking call, the thread hands us an async
    callable; we await it here on the Trio side (capturing its outcome,
    including exceptions such as cancellation) and send the outcome back.
    When the worker reports completion, its captured result is unwrapped,
    so the return value or exception of ``callable`` surfaces to our caller.

    Note: ``SyncWrapper.send`` blocks the Trio thread while the worker runs
    between two checkpoints, hence "dangerously".
    """
    worker = SyncWrapper(callable, args)
    worker.start()
    reply = None
    while True:
        # Resume the worker with the previous outcome and block until it
        # either hits the next virtual checkpoint or finishes.
        step = worker.send(reply)
        if worker.complete:
            break
        reply = await outcome.acapture(step)
    return step.unwrap()
class SyncWrapper(threading.Thread):
    """Worker thread that runs sync code and yields to Trio at patched calls.

    Communication with the Trio thread goes through two queues:

    * ``in_queue``  - values sent *to* the worker (outcomes of async steps);
    * ``out_queue`` - items sent *from* the worker: either an async callable
      to run on the Trio side, or, once ``complete`` is set, the captured
      outcome of ``callable(*args)``.
    """

    def __init__(self, callable, args):
        super().__init__()
        self.callable = callable
        self.args = args
        self.in_queue = Queue()
        self.out_queue = Queue()
        self.complete = False

    def run(self):
        # Make this wrapper discoverable by make_sync() calls in this thread.
        _portal.portal = self
        # Wait for the first send(); the value it carries is not used.
        self.in_queue.get()
        final = outcome.capture(self.callable, *self.args)
        self.complete = True
        self.out_queue.put(final)

    def send(self, value):
        """Resume the worker with ``value`` and block until its next step.

        The patched modules are installed in ``sys.modules`` for as long as
        the worker is running, so ``import socket`` in the sync code picks up
        the stand-ins.
        """
        with monkeypatched_context():
            self.in_queue.put(value)  # let the worker thread continue...
            return self.out_queue.get()  # ...and block until it yields again
_portal = threading.local()
def make_sync(afn):
def inner(*args, **kwargs):
portal = _portal.portal
portal.out_queue.put(partial(afn, *args, **kwargs))
return portal.in_queue.get().unwrap() # blocking
return inner
class PortalizedSocket:
    """Blocking-looking socket facade whose operations run on the Trio thread.

    Every operation is an ``async def`` made synchronous with ``make_sync``.
    When called from code running in a ``SyncWrapper`` thread, the coroutine
    is executed on the Trio side and the calling thread blocks until its
    outcome comes back. Only the small subset of the ``socket.socket`` API
    needed by this proof of concept is provided.
    """

    @make_sync
    async def __init__(self, *args, **kwargs):
        # well, trio.socket.socket() needs an async context so
        # we make a very spurious async __init__()... don't judge me.
        self._socket = trio.socket.socket(*args, **kwargs)

    def __enter__(self):
        # Support ``with socket.socket(...) as s:`` like the real socket.
        return self

    def __exit__(self, *exc_info):
        # Close through the portal; never suppress the exception.
        self.close()

    @make_sync
    async def connect(self, *args, **kwargs):
        await self._socket.connect(*args, **kwargs)

    @make_sync
    async def send(self, *args, **kwargs):
        # socket.socket.send() returns the number of bytes sent; pass Trio's
        # count through so sync callers that check it keep working.
        return await self._socket.send(*args, **kwargs)

    @make_sync
    async def recv(self, *args, **kwargs):
        return await self._socket.recv(*args, **kwargs)

    @make_sync
    async def close(self, *args, **kwargs):
        # Trio's close() is synchronous; no checkpoint happens here.
        self._socket.close(*args, **kwargs)
@make_sync
async def portalized_getaddrinfo(host, port, family=0, type=0, proto=0, flags=0):
    """Stand-in for ``socket.getaddrinfo`` used inside wrapped sync code.

    The lookup is delegated to ``trio.socket.getaddrinfo`` on the Trio
    thread; the calling worker thread blocks until the result is available.
    """
    resolved = await trio.socket.getaddrinfo(host, port, family, type, proto, flags)
    return resolved
# Replacement namespaces, keyed by the name of the module they stand in for.
# Each inner mapping lists the attributes exposed by the fake module.
_monkeypatches = dict(
    socket=dict(
        getaddrinfo=portalized_getaddrinfo,
        socket=PortalizedSocket,
        AF_INET=trio.socket.AF_INET,
        SOCK_STREAM=trio.socket.SOCK_STREAM,
    ),
)
def _prepare_monkeypatched_modules():
for module_name, namespace in _monkeypatches.items():
module = types.ModuleType('module_name')
for name, value in namespace.items():
setattr(module, name, value)
yield module_name, module
# Stand-in modules built once at import time, keyed by the module name they
# replace in ``sys.modules``.
_patched_modules = {name: module for name, module in _prepare_monkeypatched_modules()}
@contextmanager
def monkeypatched_context(modules=None):
    """Temporarily install stand-in modules into ``sys.modules``.

    While the context is active, ``import <name>`` for every patched name
    returns the stand-in, because imports consult ``sys.modules`` first. On
    exit, normal or exceptional, the previous entries are restored, and
    names that were absent before are removed again.

    Args:
        modules: mapping of module name -> module object to install.
            Defaults to the module-level ``_patched_modules``.
    """
    if modules is None:
        modules = _patched_modules
    saved = {}
    for module_name, module in modules.items():
        saved[module_name] = sys.modules.get(module_name)
        sys.modules[module_name] = module
    try:
        yield
    finally:
        # Restore even if the body raised (e.g. KeyboardInterrupt while the
        # caller is blocked waiting on a worker thread); otherwise the fake
        # modules would stay installed for the rest of the program.
        for module_name, previous in saved.items():
            if previous is not None:
                sys.modules[module_name] = previous
            else:
                sys.modules.pop(module_name, None)
def some_sync_code(i):
    """Fetch http://httpbin.org/get using plain blocking socket calls.

    Written as ordinary synchronous code so it can run directly or under
    wrap_it_dangerously(), where ``import socket`` resolves to the patched
    stand-in module.

    Args:
        i: label printed in the progress messages.

    Returns:
        A bytearray holding the raw HTTP response (headers and body).
    """
    import socket
    sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)
    try:
        # GET http://httpbin.org/get
        addrinfo = socket.getaddrinfo('httpbin.org', 80, socket.AF_INET, socket.SOCK_STREAM)
        sockaddr = addrinfo[0][4]
        print(f'[{i}] connecting to {sockaddr}')
        sock.connect(sockaddr)
        print(f'[{i}] connected')
        sock.send(b'GET /get HTTP/1.1\r\nHost: httpbin.org\r\nConnection: close\r\n\r\n')
        print(f'[{i}] sent bytes')
        buf = bytearray()
        while True:
            chunk = sock.recv(1024)
            if not chunk:  # peer closed the connection
                break
            buf += chunk
        print(f'[{i}] connection closed')
    finally:
        # Close on every path, including lookup/connect errors and a
        # cancellation re-raised through the portal, so the socket never leaks.
        sock.close()
    return buf
if __name__ == '__main__':
    async def call_it_like_it_were_async(i, nursery):
        """Run one wrapped request, then cancel the sibling tasks."""
        response = await wrap_it_dangerously(some_sync_code, i)
        # only print one result, and test cancellation while we're at it
        nursery.cancel_scope.cancel()
        print(response.decode('ascii'))

    async def amain():
        """Start ten wrapped requests concurrently in one nursery."""
        async with trio.open_nursery() as nursery:
            for task_id in range(10):
                nursery.start_soon(call_it_like_it_were_async, task_id, nursery)

    print('== sync version:')
    sync_response = some_sync_code('sync')
    print(sync_response.decode('ascii'))
    print('== once again, but async-wrapped:')
    trio.run(amain)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment