Created
March 22, 2022 18:19
-
-
Save c200chromebook/1ee304161a39b247e8a029ff9d2b3cd7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pickle

import numba as nb
import numba.cuda as cuda
import numpy as np
from numba.core import sigutils
from numba.cuda.codegen import CUDACodeLibrary
from numba.cuda.compiler import Dispatcher
from numba.cuda.compiler import _Kernel
def _reduce_states(self): | |
""" | |
Reduce the instance for serialization. | |
Compiled definitions are serialized in PTX form. | |
Type annotation are discarded. | |
Thread, block and shared memory configuration are serialized. | |
Stream information is discarded. | |
""" | |
return dict(cooperative=self.cooperative, name=self.entry_name, | |
signature=self.signature, codelibrary=self._codelibrary, | |
debug=self.debug, lineinfo=self.lineinfo, | |
call_helper=self.call_helper, extensions=self.extensions) | |
@classmethod | |
def _rebuild(cls, cooperative, name, signature, codelibrary, debug, lineinfo, call_helper, extensions): | |
""" | |
Rebuild an instance. | |
""" | |
instance = cls.__new__(cls) | |
# invoke parent constructor | |
super(cls, instance).__init__() | |
# populate members | |
instance.cooperative = cooperative | |
instance.entry_name = name | |
instance.signature = signature | |
instance._type_annotation = None | |
instance._codelibrary = codelibrary | |
instance.debug = debug | |
instance.lineinfo = lineinfo | |
instance.call_helper = call_helper | |
instance.extensions = extensions | |
return instance | |
def _reduce_states_codelib(self): | |
""" | |
Reduce the instance for serialization. We retain the PTX and cubins, | |
but loaded functions are discarded. They are recreated when needed | |
after deserialization. | |
""" | |
if self._linking_files: | |
msg = ('cannot pickle CUDACodeLibrary function with additional ' | |
'libraries to link against') | |
raise RuntimeError(msg) | |
return dict( | |
codegen=None, | |
name=self.name, | |
entry_name=self._entry_name, | |
module=None, | |
linking_libraries=self._linking_libraries, | |
ptx_cache=self._ptx_cache, | |
cubin_cache={}, # self._cubin_cache | |
linkerinfo_cache={}, | |
max_registers=self._max_registers, | |
nvvm_options=self._nvvm_options | |
) | |
@classmethod | |
def _rebuild_codelib(cls, codegen, name, entry_name, module, linking_libraries, | |
ptx_cache, cubin_cache, linkerinfo_cache, max_registers, | |
nvvm_options): | |
""" | |
Rebuild an instance. | |
""" | |
instance = cls.__new__(cls) | |
super(cls, instance).__init__(codegen, name) | |
instance._entry_name = entry_name | |
instance._module = module | |
instance._linking_libraries = linking_libraries | |
instance._linking_files = set() | |
instance._ptx_cache = ptx_cache | |
instance._cubin_cache = cubin_cache | |
instance._linkerinfo_cache = linkerinfo_cache | |
instance._cufunc_cache = {} | |
instance._max_registers = max_registers | |
instance._nvvm_options = nvvm_options | |
return instance | |
@classmethod
def _rebuild_dispatcher(cls, py_func, sigs, targetoptions, overloads):
    """
    Rebuild a dispatcher from previously reduced state.

    Only the FIRST signature in ``sigs`` and the FIRST value in ``overloads``
    are restored; any further entries are ignored.

    Args:
        py_func: the Python function, forwarded to ``cls``.
        sigs: sequence of signatures; ``sigs[0]`` is used.
        targetoptions: forwarded to ``cls``.
        overloads: mapping whose first value is the already-built kernel.

    Returns:
        A ``cls`` instance with the kernel registered and ``_can_compile``
        set to ``False``.

    Raises:
        IndexError: if ``sigs`` or ``overloads`` is empty.
    """
    kernel = list(overloads.values())[0]
    sig = sigs[0]
    # ``None`` is passed for the signatures so the constructor is not handed
    # anything to compile; the existing kernel is attached manually below.
    instance = cls(py_func, None, targetoptions)
    argtypes, _ = sigutils.normalize_signature(sig)
    c_sig = [a._code for a in argtypes]
    instance._insert(c_sig, kernel, cuda=True)
    instance.overloads[argtypes] = kernel
    kernel.bind()
    instance.sigs.append(sig)
    # NOTE(review): presumably marks the dispatcher as specialised so it does
    # not compile new overloads -- confirm against the installed numba.
    instance._can_compile = False
    return instance
def _reduce_states_dispatcher(self): | |
""" | |
Reduce the instance for serialization. | |
Compiled definitions are PRESERVED. | |
""" | |
return dict(py_func=self.py_func, sigs=self.sigs, | |
targetoptions=self.targetoptions, overloads=self.overloads) | |
# Monkey-patch the serialization hooks of the numba CUDA classes with the
# replacements defined in this file. The ``_rebuild*`` functions are already
# wrapped in ``classmethod``, so plain attribute assignment is sufficient.
# NOTE(review): assumes these classes look up ``_reduce_states`` / ``_rebuild``
# when pickled -- confirm against the installed numba version.
_Kernel._reduce_states = _reduce_states
_Kernel._rebuild = _rebuild

CUDACodeLibrary._reduce_states = _reduce_states_codelib
CUDACodeLibrary._rebuild = _rebuild_codelib

Dispatcher._reduce_states = _reduce_states_dispatcher
Dispatcher._rebuild = _rebuild_dispatcher
@cuda.jit(nb.void(nb.int32[:]))
def sample_func(arg):
    """Increment the first element of ``arg`` in place."""
    arg[0] += 1


# Demonstration: pickle the compiled kernel, restore it, and launch the
# restored copy with a single block of a single thread.
kern = pickle.dumps(sample_func)
func2 = pickle.loads(kern)

# The kernel is declared for int32[:]. ``np.array([1, 2, 3])`` would use the
# platform default integer (int64 on most 64-bit systems), which does not
# match that signature, so the dtype is given explicitly.
data = np.array([1, 2, 3], dtype=np.int32)
func2[1, 1](data)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment