'''
A proof of concept for pickling / unpickling an asyncio.Task object.

Limitations of this demo code:
- Requires Stackless Python 3.7, because this Python implementation can pickle coroutine objects.
- Uses the pure Python task implementation asyncio.tasks._PyTask.

Copyright (C) 2021 Anselm Kruis
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
'''
import sys
import asyncio
import logging
import pickle
import pickletools
import copyreg
import contextvars
import io
# Extend pickling support a bit. Not perfect, just enough for the proof of concept.

# pickling for objects of type asyncio.tasks._PyTask
def reduce_asyncio_tasks__PyTask(task):
    # Attributes of _PyTask:
    # _loop                 (from _PyFuture)
    # _source_traceback     (init with format_helpers.extract_stack(sys._getframe(1)), from _PyFuture)
    # _log_destroy_pending  (not always present)
    # _must_cancel
    # _fut_waiter
    # _coro
    # _context
    return (object.__new__, (asyncio.tasks._PyTask,), task.__dict__)
copyreg.pickle(asyncio.tasks._PyTask, reduce_asyncio_tasks__PyTask)
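# Note on the reduce tuple above: pickle re-creates the task by calling
# object.__new__(asyncio.tasks._PyTask) and then restores the third element,
# task.__dict__, as instance state (with the default state handling this is a
# plain __dict__.update()). Task.__init__ is deliberately bypassed, because it
# would immediately schedule the new task on an event loop.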
# pickling for objects of type contextvars.Context
def reduce_contextvars_Context(context):
    # unpickle as a copy of the context that is current during unpickling
    return (contextvars.copy_context, ())
copyreg.pickle(contextvars.Context, reduce_contextvars_Context)
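# Rationale: context variable values cannot be set from outside Context.run(),
# so instead of reconstructing the original context entry by entry, the task
# is simply re-attached to a copy of whatever context is current during
# unpickling. Context-variable values set by the flow before pickling are
# therefore not preserved; a known limitation of this demo.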
# # In order to support pickling of the C implementation of asyncio.Task, more effort is required.
# # Probably we need an extension module.
#
# # pickling for objects of type _asyncio.Task / asyncio.tasks._CTask
# def reduce__asyncio_Task(task):
#     raise pickle.PicklingError("can't pickle {} objects".format(str(type(task))))
# copyreg.pickle(asyncio.tasks._CTask, reduce__asyncio_Task)
#
# # pickling for objects of type _asyncio.FutureIter
# def getFutureIter():
#     # Create a new loop and destroy it later on. If this has side effects,
#     # we can run this function in a separate thread.
#     loop = asyncio.new_event_loop()
#     try:
#         return type(asyncio.Future(loop=loop).__iter__())
#     finally:
#         loop.close()
#
# def reduce__asyncio_FutureIter(fut_iter):
#     raise pickle.PicklingError("can't pickle {} objects".format(str(type(fut_iter))))
# copyreg.pickle(getFutureIter(), reduce__asyncio_FutureIter)
# del getFutureIter
class PicklerWithExternalObjects(pickle._Pickler):
    '''A pickler that replaces selected objects by persistent IDs.

    `external_map` maps id(obj) to the persistent ID to emit for obj.
    '''
    def __init__(self, file, protocol=None, *, fix_imports=True, external_map=None):
        super().__init__(file, protocol, fix_imports=fix_imports)
        self.external_map = external_map

    def persistent_id(self, obj):
        if self.external_map:
            return self.external_map.get(id(obj))
        return None

    @classmethod
    def dumps(cls, obj, protocol=None, *, fix_imports=True, external_map=None):
        f = io.BytesIO()
        cls(f, protocol, fix_imports=fix_imports, external_map=external_map).dump(obj)
        res = f.getvalue()
        assert isinstance(res, (bytes, bytearray))
        return res
class UnPicklerWithExternalObjects(pickle.Unpickler):
    '''An unpickler that resolves persistent IDs back to external objects.

    `external_map` is the inverse of the pickler's map: it maps a persistent
    ID to the object to substitute for it.
    '''
    def __init__(self, file, *, fix_imports=True, encoding="ASCII", errors="strict", external_map=None):
        super().__init__(file, fix_imports=fix_imports, encoding=encoding, errors=errors)
        self.external_map = external_map

    def persistent_load(self, pid):
        # This method is invoked whenever a persistent ID is encountered.
        # Here, pid is the ID returned by PicklerWithExternalObjects.persistent_id().
        if self.external_map is None:
            raise pickle.UnpicklingError("external_map not set")
        try:
            return self.external_map[pid]
        except KeyError:
            # Always raise an error if you cannot return the correct object.
            # Otherwise, the unpickler will think None is the object referenced
            # by the persistent ID.
            raise pickle.UnpicklingError("unsupported persistent object")

    @classmethod
    def loads(cls, s, *, fix_imports=True, encoding="ASCII", errors="strict", external_map=None):
        if isinstance(s, str):
            raise TypeError("Can't load pickle from unicode string")
        file = io.BytesIO(s)
        return cls(file, fix_imports=fix_imports, encoding=encoding,
                   errors=errors, external_map=external_map).load()
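# A minimal round trip through the persistent-ID mechanism (illustrative
# only, not part of the demo flow): the loop object itself is never
# serialized, only the key 'EventLoop' is, and unpickling maps the key back
# to the very same loop object.
#
#   loop = asyncio.new_event_loop()
#   data = PicklerWithExternalObjects.dumps(('payload', loop),
#                                           external_map={id(loop): 'EventLoop'})
#   obj = UnPicklerWithExternalObjects.loads(data,
#                                            external_map={'EventLoop': loop})
#   assert obj[1] is loop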
# A portable extension of the asyncio event loop: adds the method create_py_task(self, coro).
class FgEventLoopPolicy(asyncio.DefaultEventLoopPolicy):
    def __init__(self):
        super().__init__()
        # Create and immediately close a throwaway loop just to discover the
        # concrete event loop class of the current platform, then subclass it.
        loop = super().new_event_loop()
        loop.close()

        class FgEventLoop(type(loop)):
            def create_py_task(self, coro):
                '''Like create_task(), but always creates a pure Python task'''
                self._check_closed()
                task = asyncio.tasks._PyTask(coro, loop=self)
                if task._source_traceback:
                    del task._source_traceback[-1]
                return task

        self.fg_event_loop_factory = FgEventLoop

    def new_event_loop(self):
        return self.fg_event_loop_factory()
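# create_py_task() follows the same pattern as BaseEventLoop.create_task() in
# Python 3.7, except that it always instantiates asyncio.tasks._PyTask instead
# of the (possibly C-implemented) asyncio.Task. Deleting the last entry of
# _source_traceback is the same cosmetic fix create_task() applies: it hides
# the loop-internal frame from debug tracebacks.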
# Flow Controller

class _FlowControllerDoneFutureSurrogate(object):
    '''A surrogate for an asyncio.Future that is already done.

    Instances of this class are used to inject the result of an
    API call (i.e. FlowController.serialize()) into the unpickled flow.
    '''
    def __init__(self, result, exception=None):
        self._result = result
        self._exception = exception

    def done(self):
        return True

    def result(self):
        if self._exception is not None:
            raise self._exception
        return self._result
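# The surrogate stands in for the inner task the pickled flow was awaiting:
# the suspended Future.__await__ generator frame inside the pickled coroutine
# references task._fut_waiter, which is mapped to the surrogate on unpickling.
# When the restored task is stepped, __await__ resumes, sees done() == True
# and returns result(), so `send` (or `throw`) becomes the value of the
# `await FlowController.serialize(...)` expression inside the resumed flow.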
class FlowController(object):
    """Knows the tasks of a single flow.

    A FlowController is responsible for
    - starting,
    - serializing and
    - deserializing a flow.
    It works asynchronously, so it can be used in an asyncio program.
    The FlowController is stored in the context variable `current`.
    """
    current = contextvars.ContextVar('flowController')

    def __init__(self):
        self._controller_task = None  # not really required, just used in assertions
        self._task = None
        self._ignore_cancel = False
        self._result = None
        self._exception = None
    @classmethod
    async def serialize(cls, do_stop):
        """Serialize a flow task.

        This class method must be awaited from the flow task to be serialized.
        """
        self = cls.current.get()  # get the FlowController from the context
        loop = asyncio.get_running_loop()
        # Switch tasks to make the current (flow) task inactive; only a
        # suspended task can be pickled.
        return await loop.create_py_task(self._serialize_impl(do_stop, asyncio.current_task()))

    async def _serialize_impl(self, do_stop, task):
        assert task is self._task
        assert self._controller_task._fut_waiter is task
        assert task._fut_waiter is asyncio.current_task()
        loop = asyncio.get_running_loop()
        callbacks = task._callbacks[:]
        try:
            # Temporarily remove the done callbacks; they reference loop
            # internals that cannot be pickled.
            for c in callbacks:
                task.remove_done_callback(c[0])
            p = PicklerWithExternalObjects.dumps(
                task,
                external_map={id(loop): 'EventLoop',
                              id(task._fut_waiter): 'CurrentFuture',  # see assert above
                              id(self): 'FlowController',
                              id(asyncio.tasks._PyTask): 'GLOBAL asyncio.tasks._PyTask'})
        finally:
            for c in callbacks:
                task.add_done_callback(c[0], context=c[1])
        if do_stop:
            self._result = p
            self._ignore_cancel = True
            task.cancel()
        return p
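    # Note on do_stop: task.cancel() raises CancelledError inside the flow
    # task after the pickle has been taken; _run_flow() below sees
    # _ignore_cancel and returns the stored pickle instead of propagating
    # the cancellation.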
    async def start_flow(self, corofunc, *args, **kwargs):
        """Start a flow task."""
        loop = asyncio.get_running_loop()
        token = self.current.set(self)
        self._controller_task = asyncio.current_task()
        try:
            self._task = loop.create_py_task(corofunc(*args, **kwargs))
        finally:
            self.current.reset(token)
        return await self._run_flow()

    async def continue_flow(self, pickle_bytes, send=None, throw=None):
        """Continue a pickled flow task."""
        assert self._task is None
        loop = asyncio.get_running_loop()
        self._controller_task = asyncio.current_task()
        token = self.current.set(self)
        try:
            done_fut = _FlowControllerDoneFutureSurrogate(send, throw)
            task = UnPicklerWithExternalObjects.loads(
                pickle_bytes,
                external_map={'EventLoop': loop,
                              'CurrentFuture': done_fut,
                              'FlowController': self,
                              'GLOBAL asyncio.tasks._PyTask': asyncio.tasks._PyTask})
            assert task._fut_waiter is done_fut
            task._fut_waiter = None  # call_soon sets this value to self._controller_task
            asyncio.tasks._register_task(task)  # unpickling does not register it
            loop.call_soon(task._Task__step, context=task._context)
            # now task is a valid task, scheduled and ready to run
            self._task = task
        finally:
            self.current.reset(token)
        return await self._run_flow()
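    # task._Task__step above is the name-mangled private Task.__step() method.
    # Scheduling it with call_soon() and re-registering the task reproduces
    # what Task.__init__() normally does when a task is created: bookkeeping
    # that unpickling alone cannot restore.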
    async def _run_flow(self):
        '''Backend for `start_flow` and `continue_flow`'''
        try:
            return await self._task
        except asyncio.CancelledError:
            if self._ignore_cancel:
                print("_run_flow: CancelledError ignored")
                if self._exception is not None:
                    raise self._exception
                return self._result
            raise  # rethrow
# Flow Definition

async def flow_main_function(do_stop):
    print("flow_main_function: start")
    res = await flow_sub_function(do_stop)
    print("flow_main_function: end")
    return res

async def flow_sub_function(do_stop):
    print("flow_sub_function: start")
    res = await FlowController.serialize(do_stop)
    print("flow_sub_function: end")
    return res
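# With do_stop=False the flow keeps running after serialize() and returns the
# pickle as its result; with do_stop=True it would be cancelled immediately
# after the pickle has been taken.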
# Main

async def amain(*args):
    controller = FlowController()
    res = await controller.start_flow(flow_main_function, False)
    print('Flow 1 done, _result type: ', type(res))
    # res = pickletools.optimize(res); pickletools.dis(res)
    controller2 = FlowController()
    res = await controller2.continue_flow(res, 4711)
    print('Flow 2 done, _result type: ', type(res))
    assert res == 4711

def main(*args):
    logging.basicConfig(level=logging.DEBUG)
    print("start")
    asyncio.set_event_loop_policy(FgEventLoopPolicy())
    asyncio.run(amain(*args), debug=True)
    print("end")

if __name__ == '__main__':
    sys.exit(main(*sys.argv))
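# Expected behaviour: flow 1 runs to completion and returns the pickled task
# as bytes; flow 2 unpickles that task and resumes it with send=4711, so the
# suspended `await FlowController.serialize(...)` evaluates to 4711 and the
# assertion in amain() holds.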