Created August 31, 2023 13:38
-
-
Save dlee992/517b9fa1ff86d3b3c7524e85918e7b5a to your computer and use it in GitHub Desktop.
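Caching patch: extend Numba's cache-invalidation dependency tracking to `types.BoundFunction` (in addition to `types.Dispatcher` and `types.Function`).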
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/numba/core/caching.py b/numba/core/caching.py | |
index a830524acea..1b7eac4ae86 100644 | |
--- a/numba/core/caching.py | |
+++ b/numba/core/caching.py | |
@@ -793,7 +793,7 @@ def make_library_cache(prefix): | |
# list of types for which the cache invalidation is extended to their | |
# definitions | |
-dep_types = (types.Dispatcher, types.Function) | |
+dep_types = (types.Dispatcher, types.Function, types.BoundFunction) | |
def get_function_dependencies(overload: OverloadData | |
@@ -843,7 +843,7 @@ def get_deps_info(fc_ty: pt.Union[types.Dispatcher, types.Function], sig | |
if isinstance(fc_ty, types.Dispatcher): | |
dispatcher = fc_ty.dispatcher | |
deps_stamps = dispatcher.cache_deps_info(sig) | |
- elif isinstance(fc_ty, types.Function): | |
+ elif isinstance(fc_ty, (types.Function, types.BoundFunction)): | |
if hasattr(fc_ty.typing_key, "_dispatcher"): | |
# this case captures DUFuncs and GUFuncs | |
dispatcher = fc_ty.key[0]._dispatcher | |
@@ -853,10 +853,17 @@ def get_deps_info(fc_ty: pt.Union[types.Dispatcher, types.Function], sig | |
# overload | |
# If the template does not have `get_cache_deps_info` it might be | |
# a generated class for a global value in Registry.register_global | |
- deps_stamps = [tmplt.get_cache_deps_info(tmplt, sig, get_function_dependencies) | |
- for tmplt in fc_ty.templates | |
- if hasattr(tmplt, 'get_cache_deps_info')] | |
- deps_stamps = {k: v for d in deps_stamps for k, v in d.items()} | |
+ if isinstance(fc_ty, types.Function): | |
+ deps_stamps = [tmplt.get_cache_deps_info(tmplt, sig, get_function_dependencies) | |
+ for tmplt in fc_ty.templates | |
+ if hasattr(tmplt, 'get_cache_deps_info')] | |
+ deps_stamps = {k: v for d in deps_stamps for k, v in d.items()} | |
+ elif isinstance(fc_ty, types.BoundFunction): | |
+ tmplt = fc_ty.template | |
+ deps_stamps = [] | |
+ if hasattr(tmplt, 'get_cache_deps_info'): | |
+ deps_stamps.append(tmplt.get_cache_deps_info(tmplt, sig, get_function_dependencies)) | |
+ deps_stamps = {k: v for d in deps_stamps for k, v in d.items()} | |
return deps_stamps | |
@@ -877,7 +884,7 @@ def get_impl_filenames(fc_ty: pt.Union[types.Dispatcher, types.Function] | |
dispatcher = fc_ty.dispatcher | |
py_func = dispatcher.py_func | |
py_files = [py_func.__code__.co_filename] | |
- elif isinstance(fc_ty, types.Function): | |
+ elif isinstance(fc_ty, (types.Function, types.BoundFunction)): | |
if hasattr(fc_ty.typing_key, "_dispatcher"): | |
# this case captures DUFuncs and GUFuncs | |
dispatcher = fc_ty.key[0]._dispatcher | |
@@ -886,8 +893,13 @@ def get_impl_filenames(fc_ty: pt.Union[types.Dispatcher, types.Function] | |
else: | |
# a type of Function with a dispatcher associated. Probably an | |
# overload | |
- py_files = [tmplt.get_template_info(tmplt)["filename"] | |
- for tmplt in fc_ty.templates] | |
+ if isinstance(fc_ty, types.Function): | |
+ py_files = [tmplt.get_template_info(tmplt)["filename"] | |
+ for tmplt in fc_ty.templates] | |
+ elif isinstance(fc_ty, types.BoundFunction): | |
+ tmplt = fc_ty.template | |
+ info = tmplt.get_template_info(tmplt) | |
+ py_files = [info["filename"], ] | |
# the base path depends on what tmplt.get_template_info is doing | |
# in this case, the filenames returned by get_template_info are |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.