Skip to content

Instantly share code, notes, and snippets.

@beam2d
Created March 15, 2018 08:55
Show Gist options
  • Save beam2d/2217ca13735e056e10a3fdd8421a9e2f to your computer and use it in GitHub Desktop.
Save beam2d/2217ca13735e056e10a3fdd8421a9e2f to your computer and use it in GitHub Desktop.
import ast
import functools
import inspect
import textwrap

import cupy as cp
import numpy as np
def cupify(f):
    """Return a wrapper that runs ``f`` with CuPy when it receives CuPy arrays.

    The wrapper picks the array module of its arguments with
    ``cupy.get_array_module``. When that is NumPy, ``f`` is called unchanged.
    Otherwise a CuPy variant of ``f`` is built once, lazily: the source text of
    ``f`` is re-executed in a copy of ``f``'s globals in which every global name
    bound to the ``numpy`` module is rebound to ``cupy``.

    Decorators are stripped from the re-executed definition, so ``cupify`` can
    also be applied with ``@cupify`` syntax; indented source (e.g. a function
    defined inside another block) is dedented before being executed.

    Limitations:
        * Only global aliases of the ``numpy`` module object are swapped; names
          imported from it (``from numpy import exp``) keep the NumPy version.
        * ``f``'s globals are snapshotted when the CuPy variant is first built.
        * ``f`` must not depend on closure variables, because its source is
          re-executed outside any enclosing scope.
        * Other decorators on ``f`` are not re-applied to the CuPy variant.

    Args:
        f: A function whose source is retrievable by ``inspect``.

    Returns:
        The dispatching wrapper, carrying ``f``'s metadata via
        ``functools.wraps``.
    """
    cp_f = None

    def _build_cupy_variant():
        # Re-create f from its source text with numpy aliases rebound to cupy.
        lines, first_lineno = inspect.getsourcelines(f)
        tree = ast.parse(textwrap.dedent(''.join(lines)))
        # Keep line numbers in tracebacks aligned with the real source file.
        ast.increment_lineno(tree, max(first_lineno - 1, 0))
        for node in tree.body:
            if (isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
                    and node.name == f.__name__):
                # getsource includes decorator lines (possibly @cupify
                # itself); re-applying them would wrap the new function again.
                node.decorator_list = []
        cp_glb = dict(f.__globals__)
        for name, value in f.__globals__.items():
            if value is np:
                cp_glb[name] = cp
        filename = inspect.getsourcefile(f) or '<cupify>'
        exec(compile(tree, filename, 'exec'), cp_glb)
        return cp_glb[f.__name__]

    @functools.wraps(f)
    def wrapped(*args, **kw):
        # Inspect keyword arguments as well, so CuPy arrays passed by keyword
        # are routed to the CuPy variant.
        xp = cp.get_array_module(*args, *kw.values())
        if xp is np:
            return f(*args, **kw)
        nonlocal cp_f
        if cp_f is None:
            cp_f = _build_cupy_variant()
        return cp_f(*args, **kw)

    return wrapped
def foo(x):
    """Return ``exp(x) + log(x)`` computed elementwise.

    NumPy is referenced through the module-global name ``np`` rather than a
    local import, because ``cupify`` produces its CuPy version by rebinding
    that global name.
    """
    exponential = np.exp(x)
    logarithm = np.log(x)
    return exponential + logarithm
# Wrap foo explicitly so each call dispatches on its input's array module.
foo = cupify(foo)

if __name__ == "__main__":
    # Same function applied to a NumPy array and to a CuPy array.
    print(foo(np.array([1, 2, 3])))
    print(foo(cp.array([1, 2, 3])))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment