Created
October 15, 2012 09:13
-
-
Save mrjoes/3891601 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools | |
import sys | |
import types | |
from tornado.gen import (engine, YieldPoint, Multi, Task, Runner, Arguments, | |
LeakedCallbackError, BadYieldError) | |
from tornado.stack_context import ExceptionStackContext, wrap | |
def async(func): | |
@functools.wraps(func) | |
def wrapper(*args, **kwargs): | |
runner = None | |
def handle_exception(typ, value, tb): | |
if runner is not None: | |
return runner.handle_exception(typ, value, tb) | |
return False | |
# Extract callback, so it is not passed to the wrapped function | |
callback = None | |
if 'callback' in kwargs: | |
callback = wrap(kwargs.pop('callback')) | |
with ExceptionStackContext(handle_exception) as deactivate: | |
ret_value = None | |
# Execute function. If function raises StopIteration, it is same as if | |
# it returned something, but did not do anything asynchronously. | |
try: | |
gen = func(*args, **kwargs) | |
except StopIteration, ex: | |
if ex.args: | |
ret_value = ex.args | |
gen = None | |
except: | |
raise | |
if isinstance(gen, types.GeneratorType): | |
runner = AsyncRunner(gen, deactivate, callback) | |
runner.run() | |
return | |
# Not asynchronous, clean up, run callback | |
assert gen is None, gen | |
deactivate() | |
if callback is not None: | |
if ret_value: | |
callback(*ret_value[0], **ret_value[1]) | |
else: | |
callback(None) | |
return wrapper | |
class AsyncRunner(Runner):
    """A ``tornado.gen.Runner`` that reports the generator's result.

    When the driven generator finishes, the optional ``callback`` is invoked:
    with the ``(args, kwargs)`` pair carried by the terminating
    ``StopIteration`` (as raised by ``ret``), or with a single ``None`` when
    the exception carries no arguments.
    """

    def __init__(self, gen, deactivate, callback):
        """Store ``callback`` and hand ``gen``/``deactivate`` to ``Runner``.

        ``callback`` may be ``None``, in which case nothing is reported.
        """
        super(AsyncRunner, self).__init__(gen, deactivate)
        self._callback = callback

    def run(self):
        """Starts or resumes the generator, running until it reaches a
        yield point that is not ready.

        On completion of the generator the stored callback, if any, is
        invoked with the generator's result.  Raises ``LeakedCallbackError``
        if the generator finishes cleanly while callbacks are still pending;
        any other exception escaping the generator is re-raised.
        """
        # NOTE(review): running/finished/exc_info/yield_point/had_exception/
        # pending_callbacks are state inherited from tornado's Runner.
        if self.running or self.finished:
            return
        try:
            self.running = True
            while True:
                if self.exc_info is None:
                    try:
                        if not self.yield_point.is_ready():
                            return
                        result = self.yield_point.get_result()
                    except Exception:
                        # Deliver the failure into the generator below.
                        self.exc_info = sys.exc_info()
                try:
                    if self.exc_info is not None:
                        self.had_exception = True
                        exc_info = self.exc_info
                        self.exc_info = None
                        yielded = self.gen.throw(*exc_info)
                    else:
                        yielded = self.gen.send(result)
                except StopIteration as ex:
                    self.finished = True
                    if self.pending_callbacks and not self.had_exception:
                        # If we ran cleanly without waiting on all callbacks
                        # raise an error (really more of a warning).  If we
                        # had an exception then some callbacks may have been
                        # orphaned, so skip the check in that case.
                        raise LeakedCallbackError(
                            "finished without waiting for callbacks %r" %
                            self.pending_callbacks)
                    self.deactivate_stack_context()

                    # Run callback if it was provided
                    if self._callback is not None:
                        if ex.args:
                            # (args, kwargs) pair packed by ret().
                            self._callback(*ex.args[0], **ex.args[1])
                        else:
                            self._callback(None)
                    return
                except Exception:
                    self.finished = True
                    raise
                if isinstance(yielded, list):
                    yielded = Multi(yielded)
                if isinstance(yielded, YieldPoint):
                    self.yield_point = yielded
                    try:
                        self.yield_point.start(self)
                    except Exception:
                        self.exc_info = sys.exc_info()
                else:
                    # One-element tuple: unpacked by gen.throw(*exc_info) on
                    # the next loop iteration.
                    self.exc_info = (BadYieldError("yielded unknown object %r" % yielded),)
        finally:
            self.running = False
def ret(*args, **kwargs):
    """Finish the calling function by raising ``StopIteration``.

    The positional arguments (as a tuple) and the keyword arguments (as a
    dict) become the two elements of the exception's ``args``.  This
    function never returns normally.
    """
    result_signal = StopIteration(args, kwargs)
    raise result_signal
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from time import time | |
from tornado.ioloop import IOLoop | |
import genmod as gen | |
def eq_(a, b):
    """Assert that ``a == b``.

    Raises ``AssertionError`` whose message shows the ``repr`` of both
    operands, so a failing comparison is diagnosable from the traceback.
    """
    assert a == b, "%r != %r" % (a, b)
def test_ping(): | |
@gen.async | |
def long_op(x, y): | |
yield gen.Task(io_loop.add_timeout, time() + 0.1) | |
gen.ret(x * y) | |
@gen.async | |
def sync_op(x, y): | |
gen.ret(x + y) | |
@gen.async | |
def sync_check(a): | |
if a == 10: | |
gen.ret(15) | |
@gen.async | |
def async_check(a): | |
yield gen.Task(io_loop.add_timeout, time() + 0.1) | |
if a == 10: | |
gen.ret(15) | |
gen.ret(20) | |
@gen.async | |
def raise_exc(): | |
raise Exception('test') | |
@gen.async | |
def args(a, b): | |
gen.ret(a, b) | |
@gen.async | |
def kwargs(a, b): | |
yield gen.Task(io_loop.add_timeout, time() + 0.1) | |
gen.ret(a=a, b=b) | |
@gen.async | |
def proc(): | |
# Make calls | |
a = yield gen.Task(long_op, 2, 3) | |
b = yield gen.Task(sync_op, a, 4) | |
eq_(b, 10) | |
# Check default return value (None) | |
res = yield gen.Task(sync_check, 10) | |
eq_(res, 15) | |
res = yield gen.Task(sync_check, 15) | |
eq_(res, None) | |
# Async check | |
res = yield gen.Task(async_check, 10) | |
eq_(res, 15) | |
res = yield gen.Task(async_check, 15) | |
eq_(res, 20) | |
# Exception check | |
try: | |
res = yield gen.Task(raise_exc) | |
except Exception, ex: | |
eq_(ex.message, 'test') | |
# Result tuple | |
res = yield gen.Task(args, 10, 20) | |
eq_(res.args, (10, 20)) | |
# Named arguments | |
res = yield gen.Task(kwargs, 10, 20) | |
eq_(res.kwargs['a'], 10) | |
eq_(res.kwargs['b'], 20) | |
# Finish test | |
io_loop.stop() | |
io_loop = IOLoop.instance() | |
io_loop.add_callback(proc) | |
io_loop.start() | |
# Run the test directly when this file is executed as a script.
if __name__ == '__main__':
    test_ping()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment