Skip to content

Instantly share code, notes, and snippets.

@LiraNuna
Last active March 18, 2017 06:55
Show Gist options
  • Save LiraNuna/a90b6ac3cc820e07022787ab07204f9a to your computer and use it in GitHub Desktop.
asynq dict fanout bug
from collections import defaultdict
import asynq
import multiget_cache
from multiget_cache import base_cache_wrapper
from multiget_cache import function_tools
class AsyncCacheWrapper(base_cache_wrapper.BaseCacheWrapper):
instances = []
def __init__(self, inner_f, object_key, argument_key=None, default_result=None,
result_key=None, object_tuple_key=None, coerce_args_to_strings=False):
self.object_key = object_key
self.result_key = result_key
self.default_result = default_result
self.object_tuple_key = object_tuple_key
self.coerce_args_to_strings = coerce_args_to_strings
self.argument_key = argument_key
if not argument_key:
self.argument_key = function_tools.get_kwargs_for_function(inner_f)
self.batch = _MultigetBatch(self)
super().__init__(inner_f)
AsyncCacheWrapper.instances.append(self)
def _make_batch_item(self, kwargs):
return _MultigetBatchItem(self.batch, kwargs)
def __call__(self, *args, **kwargs):
kwargs.update(function_tools.convert_args_to_kwargs(self.inner_f, args))
# We are forced to flush the batch to get our result. Before flushing, we must add the current parameters
# to the batch (using prime) in case it was not already primed.
self.prime(*args, **kwargs)
self.batch.flush()
# Since flushing saves to cache, we should be able to retrieve the result from there
cache = multiget_cache.get_cache()
cache_key = self.get_cache_key(**kwargs)
return cache[cache_key]
def prime(self, *args, **kwargs):
# For backwards compatibility with async_cached.
# Since `async` does what we want, and we want to get rid of `prime`, just call `async`.
self.batch.items.append(self.async(*args, **kwargs))
def async(self, *args, **kwargs):
kwargs.update(function_tools.convert_args_to_kwargs(self.inner_f, args))
return self._make_batch_item(kwargs)
class _MultigetBatch(asynq.BatchBase):
    """asynq batch that resolves all queued cache-wrapper calls in one fetch."""

    def __init__(self, cache_wrapper):
        self.cache_wrapper = cache_wrapper
        super().__init__()

    def _try_switch_active_batch(self):
        # Once this batch starts flushing, hand the wrapper a fresh batch so
        # newly queued calls accumulate separately.
        if self.cache_wrapper.batch is self:
            self.cache_wrapper.batch = _MultigetBatch(self.cache_wrapper)

    def _flush(self):
        cache = multiget_cache.get_cache()
        pending_kwargs = []
        # Skip anything already cached, then deduplicate the remainder so
        # every unique argument set is fetched exactly once.  kwargs dicts
        # are unhashable, so deduplication is an O(n) membership scan.
        for item in self.items:
            if self.cache_wrapper.get_cache_key(**item.kwargs) in cache:
                continue
            if item.kwargs not in pending_kwargs:
                pending_kwargs.append(item.kwargs)
        # Fan the per-call kwargs out into per-key lists of values.
        fanned_out = defaultdict(list)
        for kwargs in pending_kwargs:
            for key, value in kwargs.items():
                # sometimes we get a list?? — take its first element
                if hasattr(value, 'append'):
                    value = value[0]
                if self.cache_wrapper.coerce_args_to_strings:
                    value = str(value)
                fanned_out[key].append(value)
        # An empty fan-out means every queued item was already cached.
        if fanned_out:
            # Only call with kwargs because we converted earlier.
            objects = self.cache_wrapper.inner_f(
                **{key: list(set(values)) for key, values in fanned_out.items()}
            )
            # Reorder the returned objects to match priming order, filling
            # in the configured default wherever nothing came back.
            mapped_objects = function_tools.map_arguments_to_objects(
                fanned_out,
                objects,
                self.cache_wrapper.object_key,
                self.cache_wrapper.object_tuple_key,
                self.cache_wrapper.argument_key,
                self.cache_wrapper.result_key,
                self.cache_wrapper.default_result,
            )
            # Record everything fetched so the loop below can resolve items.
            for kwargs, mapped_object in zip(pending_kwargs, mapped_objects):
                cache[self.cache_wrapper.get_cache_key(**kwargs)] = mapped_object
        for item in self.items:
            if item.is_computed():
                continue
            item.set_value(cache[self.cache_wrapper.get_cache_key(**item.kwargs)])
class _MultigetBatchItem(asynq.BatchItemBase):
    """Batch item that remembers the kwargs it was queued with."""

    def __init__(self, batch, kwargs):
        super().__init__(batch)
        # Kept so _MultigetBatch._flush can compute this item's cache key.
        self.kwargs = kwargs
def async_cached(object_key, argument_key=None, default_result=None,
                 result_fields=None, join_table_name=None,
                 coerce_args_to_strings=False):
    """
    Decorator factory producing an AsyncCacheWrapper around the function.

    :param object_key: the names of the attributes on the result object that are meant to match the function parameters
    :param argument_key: the function parameter names you wish to match with the `object_key`s.
        By default, this will be all of your wrapped function's arguments, in order.
        So, you'd really only use this when you want to ignore a given function argument.
    :param default_result: The result to put into the cache if nothing is matched.
    :param result_fields: The attribute on your result object you wish to return the value of.
        By default, the whole object is returned.
    :param join_table_name: A temporary shortcut until we allow dot.path traversal for object_key.
        Will call getattr(getattr(result, join_table_name), object_key)
    :param coerce_args_to_strings: force coerce all arguments to the inner function to strings.
        Useful for SQL where mixes of ints and strings in `WHERE x IN (list)` clauses causes poor performance.
    :return: A wrapper that allows you to queue many O(1) calls and flush the queue all at once,
        rather than executing the inner function body N times.
    """
    # BUG FIX: `coerce_args_to_strings` was documented above but missing from
    # the signature and hardcoded to False; it is now a real keyword argument
    # (default False, so existing callers are unaffected) and is forwarded.
    def create_wrapper(inner_f):
        return AsyncCacheWrapper(
            inner_f,
            object_key,
            argument_key,
            default_result,
            result_fields,          # passed through as result_key
            join_table_name,        # passed through as object_tuple_key
            coerce_args_to_strings=coerce_args_to_strings,
        )
    return create_wrapper
def clear_cache():
    """Drop every entry from the shared multiget cache."""
    multiget_cache.clear_cache()
from collections import namedtuple
import asynq
from async_caching import async_cached, clear_cache
DataObject = namedtuple('DataObject', 'object_id')
@async_cached(object_key='object_id')
def gen_object(object_ids):
    """Return one DataObject per id; prints once per actual batch flush."""
    print("FLUSH:", object_ids)
    return list(map(DataObject, object_ids))
@asynq.async()
def returns_dict(object_id):
return (yield {
'object_id': gen_object.async(object_id)
})
@asynq.async()
def returns_array(object_id):
return (yield [
gen_object.async(object_id)
])
@asynq.async()
def root(data_generator):
object_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
return (yield [
data_generator.async(object_id)
for object_id in object_ids
])
if __name__ == '__main__':
    # Run the working list-fanout demo, then the buggy dict-fanout demo.
    demos = [
        ('### CORRECT BEHAVIOR ###', returns_array),
        ('!!! INCORRECT BEHAVIOR !!!', returns_dict),
    ]
    for index, (banner, generator) in enumerate(demos):
        if index:
            # start each subsequent run from an empty cache
            clear_cache()
        print(banner)
        root(generator)
Cython
asynq==0.1.4
multiget-cache==0.0.10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment