Gist: LiraNuna/a90b6ac3cc820e07022787ab07204f9a — "asynq dict fanout bug".
Save this gist to your computer, or open it in GitHub Desktop.
Note: this file may contain bidirectional Unicode text that can render or compile differently than it appears below; to review it, open it in an editor that reveals hidden Unicode characters.
from collections import defaultdict | |
import asynq | |
import multiget_cache | |
from multiget_cache import base_cache_wrapper | |
from multiget_cache import function_tools | |
class AsyncCacheWrapper(base_cache_wrapper.BaseCacheWrapper):
    """Cache wrapper that queues calls to ``inner_f`` in an asynq batch and
    serves results out of the multiget cache after the batch is flushed.

    NOTE(review): the ``async`` method name became a reserved keyword in
    Python 3.7, so this module only parses on Python <= 3.6 (asynq 0.1.x era).
    """

    # Class-level registry of every wrapper ever created; presumably used
    # elsewhere in the package to inspect or reset wrappers -- TODO confirm.
    instances = []

    def __init__(self, inner_f, object_key, argument_key=None, default_result=None,
                 result_key=None, object_tuple_key=None, coerce_args_to_strings=False):
        # Keys used by _MultigetBatch._flush / map_arguments_to_objects to
        # match returned objects back to the arguments that requested them.
        self.object_key = object_key
        self.result_key = result_key
        self.default_result = default_result
        self.object_tuple_key = object_tuple_key
        self.coerce_args_to_strings = coerce_args_to_strings
        self.argument_key = argument_key
        if not argument_key:
            # Default: match against all of inner_f's keyword argument names.
            self.argument_key = function_tools.get_kwargs_for_function(inner_f)
        # Each wrapper owns one "current" batch; a fresh batch replaces it
        # when this one flushes (see _MultigetBatch._try_switch_active_batch).
        self.batch = _MultigetBatch(self)
        super().__init__(inner_f)
        AsyncCacheWrapper.instances.append(self)

    def _make_batch_item(self, kwargs):
        # Wrap already-normalized kwargs in a batch item tied to the current batch.
        return _MultigetBatchItem(self.batch, kwargs)

    def __call__(self, *args, **kwargs):
        """Synchronous entry point: prime, flush, then read the cached result."""
        kwargs.update(function_tools.convert_args_to_kwargs(self.inner_f, args))
        # We are forced to flush the batch to get our result. Before flushing, we must add the current parameters
        # to the batch (using prime) in case it was not already primed.
        self.prime(*args, **kwargs)
        self.batch.flush()
        # Since flushing saves to cache, we should be able to retrieve the result from there
        cache = multiget_cache.get_cache()
        cache_key = self.get_cache_key(**kwargs)
        return cache[cache_key]

    def prime(self, *args, **kwargs):
        # For backwards compatibility with async_cached.
        # Since `async` does what we want, and we want to get rid of `prime`, just call `async`.
        # NOTE(review): if asynq's BatchItemBase already registers the item
        # with the batch on construction, this append double-registers it;
        # that would only be harmless because _flush deduplicates -- confirm.
        self.batch.items.append(self.async(*args, **kwargs))

    def async(self, *args, **kwargs):
        """Queue a call with these arguments; returns a batch item (future)."""
        kwargs.update(function_tools.convert_args_to_kwargs(self.inner_f, args))
        return self._make_batch_item(kwargs)
class _MultigetBatch(asynq.BatchBase):
    """asynq batch that resolves many queued calls with a single call to the
    wrapper's ``inner_f``, caching the mapped results along the way."""

    def __init__(self, cache_wrapper):
        # The AsyncCacheWrapper this batch belongs to.
        self.cache_wrapper = cache_wrapper
        super().__init__()

    def _try_switch_active_batch(self):
        # Called by asynq around flush: install a fresh batch on the wrapper
        # so newly queued items don't land in the batch being flushed.
        if self.cache_wrapper.batch is self:
            self.cache_wrapper.batch = _MultigetBatch(self.cache_wrapper)

    def _flush(self):
        """Fetch every uncached, deduplicated kwargs set in one ``inner_f``
        call, store the mapped results in the cache, then resolve all items."""
        cache = multiget_cache.get_cache()
        kwargs_dict = defaultdict(list)
        unique_kwargs = []
        # We only want kwargs that we will fetch, and aren't in the cache already.
        # Since batching can fetch duplicate things, we deduplicate so we
        # only fetch things we need that aren't in the cache.
        for item in self.items:
            cache_key = self.cache_wrapper.get_cache_key(**item.kwargs)
            if cache_key in cache:
                continue
            # We can't hash dicts, we have to do this O(n) comparison.
            if item.kwargs not in unique_kwargs:
                unique_kwargs.append(item.kwargs)
        # Convert batch items to arguments: fan the kwargs out column-wise,
        # one list of values per parameter name.
        for kwargs in unique_kwargs:
            for key, value in kwargs.items():
                # sometimes we get a list??
                # NOTE(review): this looks like the dict-fanout symptom the
                # gist demonstrates -- when primed via a yielded dict the
                # argument appears to arrive wrapped in a list; confirm
                # against asynq's dict scheduling before relying on it.
                if hasattr(value, 'append'):
                    value = value[0]
                if self.cache_wrapper.coerce_args_to_strings:
                    value = str(value)
                kwargs_dict[key].append(value)
        # If kwargs_dict is empty, then all objects have been fetched and cached already.
        if kwargs_dict:
            # Only call with kwargs because we converted earlier
            objects = self.cache_wrapper.inner_f(
                **{key: list(set(params)) for key, params in kwargs_dict.items()}
            )
            # For the objects that were returned, reorder them such that they match
            # the order they were primed in, and if nothing was returned for a set
            # of arguments, use the provided default value
            mapped_objects = function_tools.map_arguments_to_objects(
                kwargs_dict,
                objects,
                self.cache_wrapper.object_key,
                self.cache_wrapper.object_tuple_key,
                self.cache_wrapper.argument_key,
                self.cache_wrapper.result_key,
                self.cache_wrapper.default_result,
            )
            # Add items that have been fetched already to the cache.
            # NOTE(review): zip assumes map_arguments_to_objects returns
            # results in the same order as unique_kwargs -- TODO confirm.
            for kwargs, mapped_object in zip(unique_kwargs, mapped_objects):
                cache_key = self.cache_wrapper.get_cache_key(**kwargs)
                cache[cache_key] = mapped_object
        # Resolve every queued item (including duplicates) from the cache.
        for item in self.items:
            if item.is_computed():
                continue
            cache_key = self.cache_wrapper.get_cache_key(**item.kwargs)
            item.set_value(cache[cache_key])
class _MultigetBatchItem(asynq.BatchItemBase):
    """One queued call: the batch it belongs to plus the normalized kwargs it
    was queued with (read back by _MultigetBatch._flush)."""

    def __init__(self, batch, kwargs):
        super().__init__(batch)
        # Normalized keyword arguments identifying this call.
        self.kwargs = kwargs
def async_cached(object_key, argument_key=None, default_result=None,
                 result_fields=None, join_table_name=None,
                 coerce_args_to_strings=False):
    """
    :param object_key: the names of the attributes on the result object that are meant to match the function parameters
    :param argument_key: the function parameter names you wish to match with the `object_key`s.
        By default, this will be all of your wrapped function's arguments, in order.
        So, you'd really only use this when you want to ignore a given function argument.
    :param default_result: The result to put into the cache if nothing is matched.
    :param result_fields: The attribute on your result object you wish to return the value of.
        By default, the whole object is returned.
    :param join_table_name: A temporary shortcut until we allow dot.path traversal for object_key.
        Will call getattr(getattr(result, join_table_name), object_key)
    :param coerce_args_to_strings: force coerce all arguments to the inner function to strings.
        Useful for SQL where mixes of ints and strings in `WHERE x IN (list)` clauses causes poor performance.
    :return: A wrapper that allows you to queue many O(1) calls and flush the queue all at once,
        rather than executing the inner function body N times.
    """
    # Bugfix: `coerce_args_to_strings` was documented above but missing from
    # the signature and hard-coded to False -- it is now a real keyword
    # parameter (default False, so existing callers are unaffected).
    def create_wrapper(inner_f):
        # Build the batching wrapper around the decorated function.
        return AsyncCacheWrapper(
            inner_f,
            object_key,
            argument_key,
            default_result,
            result_fields,
            join_table_name,
            coerce_args_to_strings=coerce_args_to_strings,
        )
    return create_wrapper
def clear_cache():
    """Drop every entry from the shared multiget cache."""
    multiget_cache.clear_cache()
Note: this file may contain bidirectional Unicode text that can render or compile differently than it appears below; to review it, open it in an editor that reveals hidden Unicode characters.
from collections import namedtuple | |
import asynq | |
from async_caching import async_cached, clear_cache | |
DataObject = namedtuple('DataObject', 'object_id') | |
@async_cached(object_key='object_id')
def gen_object(object_ids):
    """Batched fetch: build one DataObject per requested id, logging the
    ids that reach the (single) flush call."""
    print("FLUSH:", object_ids)
    return list(map(DataObject, object_ids))
@asynq.async() | |
def returns_dict(object_id): | |
return (yield { | |
'object_id': gen_object.async(object_id) | |
}) | |
@asynq.async() | |
def returns_array(object_id): | |
return (yield [ | |
gen_object.async(object_id) | |
]) | |
@asynq.async() | |
def root(data_generator): | |
object_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] | |
return (yield [ | |
data_generator.async(object_id) | |
for object_id in object_ids | |
]) | |
if __name__ == '__main__':
    print('### CORRECT BEHAVIOR ###')
    root(returns_array)
    # Reset the cache so the second run starts cold.
    clear_cache()
    print('!!! INCORRECT BEHAVIOR !!!')
    root(returns_dict)
Note: this file may contain bidirectional Unicode text that can render or compile differently than it appears below; to review it, open it in an editor that reveals hidden Unicode characters.
Cython
asynq==0.1.4
multiget-cache==0.0.10
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.