Skip to content

Instantly share code, notes, and snippets.

@dlee992
Created August 31, 2023 13:38
Show Gist options
  • Save dlee992/517b9fa1ff86d3b3c7524e85918e7b5a to your computer and use it in GitHub Desktop.
caching patch
diff --git a/numba/core/caching.py b/numba/core/caching.py
index a830524acea..1b7eac4ae86 100644
--- a/numba/core/caching.py
+++ b/numba/core/caching.py
@@ -793,7 +793,7 @@ def make_library_cache(prefix):
# list of types for which the cache invalidation is extended to their
# definitions
-dep_types = (types.Dispatcher, types.Function)
+dep_types = (types.Dispatcher, types.Function, types.BoundFunction)
def get_function_dependencies(overload: OverloadData
@@ -843,7 +843,7 @@ def get_deps_info(fc_ty: pt.Union[types.Dispatcher, types.Function], sig
if isinstance(fc_ty, types.Dispatcher):
dispatcher = fc_ty.dispatcher
deps_stamps = dispatcher.cache_deps_info(sig)
- elif isinstance(fc_ty, types.Function):
+ elif isinstance(fc_ty, (types.Function, types.BoundFunction)):
if hasattr(fc_ty.typing_key, "_dispatcher"):
# this case captures DUFuncs and GUFuncs
dispatcher = fc_ty.key[0]._dispatcher
@@ -853,10 +853,17 @@ def get_deps_info(fc_ty: pt.Union[types.Dispatcher, types.Function], sig
# overload
# If the template does not have `get_cache_deps_info` it might be
# a generated class for a global value in Registry.register_global
- deps_stamps = [tmplt.get_cache_deps_info(tmplt, sig, get_function_dependencies)
- for tmplt in fc_ty.templates
- if hasattr(tmplt, 'get_cache_deps_info')]
- deps_stamps = {k: v for d in deps_stamps for k, v in d.items()}
+ if isinstance(fc_ty, types.Function):
+ deps_stamps = [tmplt.get_cache_deps_info(tmplt, sig, get_function_dependencies)
+ for tmplt in fc_ty.templates
+ if hasattr(tmplt, 'get_cache_deps_info')]
+ deps_stamps = {k: v for d in deps_stamps for k, v in d.items()}
+ elif isinstance(fc_ty, types.BoundFunction):
+ tmplt = fc_ty.template
+ deps_stamps = []
+ if hasattr(tmplt, 'get_cache_deps_info'):
+ deps_stamps.append(tmplt.get_cache_deps_info(tmplt, sig, get_function_dependencies))
+ deps_stamps = {k: v for d in deps_stamps for k, v in d.items()}
return deps_stamps
@@ -877,7 +884,7 @@ def get_impl_filenames(fc_ty: pt.Union[types.Dispatcher, types.Function]
dispatcher = fc_ty.dispatcher
py_func = dispatcher.py_func
py_files = [py_func.__code__.co_filename]
- elif isinstance(fc_ty, types.Function):
+ elif isinstance(fc_ty, (types.Function, types.BoundFunction)):
if hasattr(fc_ty.typing_key, "_dispatcher"):
# this case captures DUFuncs and GUFuncs
dispatcher = fc_ty.key[0]._dispatcher
@@ -886,8 +893,13 @@ def get_impl_filenames(fc_ty: pt.Union[types.Dispatcher, types.Function]
else:
# a type of Function with a dispatcher associated. Probably an
# overload
- py_files = [tmplt.get_template_info(tmplt)["filename"]
- for tmplt in fc_ty.templates]
+ if isinstance(fc_ty, types.Function):
+ py_files = [tmplt.get_template_info(tmplt)["filename"]
+ for tmplt in fc_ty.templates]
+ elif isinstance(fc_ty, types.BoundFunction):
+ tmplt = fc_ty.template
+ info = tmplt.get_template_info(tmplt)
+ py_files = [info["filename"], ]
# the base path depends on what tmplt.get_template_info is doing
# in this case, the filenames returned by get_template_info are
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment