@ngoldbaum
Created October 29, 2019 22:08
Gist: ngoldbaum/ebca6425f766a4faf43251a6fa7f228f
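The patch below changes PyTorch's __torch_function__ dispatch to pass the public torch API function object itself, rather than the function's __name__ string, to __torch_function__ implementations. The test-suite registries are updated to key their handler dicts on the function object, and both the generated and hand-written C++ bindings resolve the function object off the torch module before dispatching.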
diff --git a/test/test_overrides.py b/test/test_overrides.py
index c9821c45db..f44523063b 100644
--- a/test/test_overrides.py
+++ b/test/test_overrides.py
@@ -29,7 +29,7 @@ def implements_diagonal(torch_function):
     """
     @functools.wraps(torch_function)
     def decorator(func):
-        HANDLED_FUNCTIONS_DIAGONAL[torch_function.__name__] = func
+        HANDLED_FUNCTIONS_DIAGONAL[torch_function] = func
         return func
     return decorator
@@ -123,7 +123,7 @@ def implements_sub(torch_function):
     "Register a torch function override for SubTensor"
     @functools.wraps(torch_function)
     def decorator(func):
-        HANDLED_FUNCTIONS_SUB[torch_function.__name__] = func
+        HANDLED_FUNCTIONS_SUB[torch_function] = func
         return func
     return decorator
@@ -169,7 +169,7 @@ def implements_sub_diagonal(torch_function):
     "Register a torch function override for SubDiagonalTensor"
     @functools.wraps(torch_function)
     def decorator(func):
-        HANDLED_FUNCTIONS_SUB_DIAGONAL[torch_function.__name__] = func
+        HANDLED_FUNCTIONS_SUB_DIAGONAL[torch_function] = func
         return func
     return decorator
@@ -205,7 +205,7 @@ def implements_tensor_like(torch_function):
     "Register a torch function override for TensorLike"
     @functools.wraps(torch_function)
     def decorator(func):
-        HANDLED_FUNCTIONS_TENSOR_LIKE[torch_function.__name__] = func
+        HANDLED_FUNCTIONS_TENSOR_LIKE[torch_function] = func
         return func
     return decorator
@@ -705,6 +705,7 @@ class TensorLike(object):
         if kwargs is None:
             kwargs = {}
         if func not in HANDLED_FUNCTIONS_TENSOR_LIKE:
+            breakpoint()
             return NotImplemented
         # In this case __torch_function__ should override TensorLike objects
         return HANDLED_FUNCTIONS_TENSOR_LIKE[func](*args, **kwargs)
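For context, here is a minimal sketch (hypothetical names, outside the patch, assuming a torch build with this change applied) of the registry pattern the test changes above now use: handlers are keyed on the torch function object itself, so __torch_function__ can dispatch with a plain dict lookup.

import functools

import torch

HANDLED_FUNCTIONS = {}

def implements(torch_function):
    """Register an override; the key is the function object, not its name."""
    @functools.wraps(torch_function)
    def decorator(func):
        HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator

class MyTensor(object):
    def __torch_function__(self, func, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # `func` is e.g. torch.mean itself, so membership tests and
        # identity comparisons work directly.
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)

@implements(torch.mean)
def mean_override(mat):
    return 0.0

Before the patch, the same registries were keyed on name strings like "mean", which ruled out identity comparisons against the torch API functions themselves.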
diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py
index 6dd14a1a45..0661f31de0 100644
--- a/tools/autograd/gen_python_functions.py
+++ b/tools/autograd/gen_python_functions.py
@@ -103,7 +103,9 @@ static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
   auto r = parser.parse(args, kwargs, parsed_args);
   if(r.has_torch_function()){
     PyObject* torch_function = PyObject_FastGetAttrString(r.get_overloaded_arg(0), "__torch_function__");
-    return PyObject_CallFunctionObjArgs(torch_function, PyUnicode_FromString(r.get_func_name().data()), args, kwargs, NULL);
+    PyObject* torch_module = PyImport_ImportModule("torch");
+    PyObject* torch_api_function = PyObject_FastGetAttrString(torch_module, const_cast<char*>(r.get_func_name().data()));
+    return PyObject_CallFunctionObjArgs(torch_function, torch_api_function, args, kwargs, NULL);
   }
   ${declare_namedtuple_return_types}
   ${dispatch}
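The generated C++ above resolves the overloaded function's name to the actual attribute on the torch module before invoking __torch_function__. A rough Python equivalent of that control flow (an illustration only, not the real binding) looks like this:

import torch

def dispatch_through_torch_function(overloaded_arg, func_name, args, kwargs):
    # Equivalent of PyImport_ImportModule("torch") followed by
    # PyObject_FastGetAttrString(torch_module, func_name):
    torch_api_function = getattr(torch, func_name)
    torch_function = overloaded_arg.__torch_function__
    # Pass the function object itself instead of a name string.
    return torch_function(torch_api_function, args, kwargs)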
diff --git a/tools/autograd/templates/python_torch_functions.cpp b/tools/autograd/templates/python_torch_functions.cpp
index 5b1af1a5c4..1ea2aabe29 100644
--- a/tools/autograd/templates/python_torch_functions.cpp
+++ b/tools/autograd/templates/python_torch_functions.cpp
@@ -417,6 +417,10 @@ static PyObject * THPVariable_nonzero(PyObject* self, PyObject* args, PyObject* kwargs)
   auto r = parser.parse(args, kwargs, parsed_args);
   if(r.has_torch_function()){
     PyObject* torch_function = PyObject_FastGetAttrString(r.get_overloaded_arg(0), "__torch_function__");
+    PyObject* torch_module = PyImport_ImportModule("torch");
+    PyObject* torch_api_function = PyObject_FastGetAttrString(torch_module, const_cast<char*>(r.get_func_name().data()));
+    return PyObject_CallFunctionObjArgs(torch_function, torch_api_function, args, kwargs, NULL);
+
     return PyObject_CallFunctionObjArgs(torch_function, PyUnicode_FromString(r.get_func_name().data()), args, kwargs, NULL);
   }
   if (r.idx == 0) {
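THPVariable_nonzero is a hand-written binding rather than a generated one, so the same change is applied to it directly; the new early return supersedes the old string-passing call left below it. Under the patched binding, an override can recognize the call by identity. A hypothetical minimal example:

import torch

class NonzeroOverride(object):
    # Hypothetical tensor-like type overriding only torch.nonzero.
    def __torch_function__(self, func, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func is torch.nonzero:  # identity check on the function object
            return ()
        return NotImplemented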
diff --git a/torch/_overrides.py b/torch/_overrides.py
index 17239247f0..2c218b1e96 100644
--- a/torch/_overrides.py
+++ b/torch/_overrides.py
@@ -127,7 +127,7 @@ def _implement_torch_function(
         # Use `public_api` instead of `implementation` so __torch_function__
         # implementations can do equality/identity comparisons.
         result = overloaded_arg.__torch_function__(
-            public_api.__name__, args, kwargs)
+            public_api, args, kwargs)
         if result is not NotImplemented:
             return result
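This hunk is the pure-Python side of the same change: _implement_torch_function hands each overloaded argument the public_api function object. A minimal sketch of that loop (simplified from the hunk above, with the surrounding error handling assumed):

def _implement_torch_function_sketch(public_api, overloaded_args, args, kwargs):
    for overloaded_arg in overloaded_args:
        # Use `public_api` instead of `implementation` so __torch_function__
        # implementations can do equality/identity comparisons.
        result = overloaded_arg.__torch_function__(public_api, args, kwargs)
        if result is not NotImplemented:
            return result
    raise TypeError("no __torch_function__ implementation found for {}"
                    .format(public_api.__name__))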