This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/test/test_overrides.py b/test/test_overrides.py | |
index c9821c45db..f44523063b 100644 | |
--- a/test/test_overrides.py | |
+++ b/test/test_overrides.py | |
@@ -29,7 +29,7 @@ def implements_diagonal(torch_function): | |
""" | |
@functools.wraps(torch_function) | |
def decorator(func): | |
- HANDLED_FUNCTIONS_DIAGONAL[torch_function.__name__] = func | |
+ HANDLED_FUNCTIONS_DIAGONAL[torch_function] = func | |
return func | |
return decorator | |
@@ -123,7 +123,7 @@ def implements_sub(torch_function): | |
"Register a torch function override for SubTensor" | |
@functools.wraps(torch_function) | |
def decorator(func): | |
- HANDLED_FUNCTIONS_SUB[torch_function.__name__] = func | |
+ HANDLED_FUNCTIONS_SUB[torch_function] = func | |
return func | |
return decorator | |
@@ -169,7 +169,7 @@ def implements_sub_diagonal(torch_function): | |
"Register a torch function override for SubDiagonalTensor" | |
@functools.wraps(torch_function) | |
def decorator(func): | |
- HANDLED_FUNCTIONS_SUB_DIAGONAL[torch_function.__name__] = func | |
+ HANDLED_FUNCTIONS_SUB_DIAGONAL[torch_function] = func | |
return func | |
return decorator | |
@@ -205,7 +205,7 @@ def implements_tensor_like(torch_function): | |
"Register a torch function override for TensorLike" | |
@functools.wraps(torch_function) | |
def decorator(func): | |
- HANDLED_FUNCTIONS_TENSOR_LIKE[torch_function.__name__] = func | |
+ HANDLED_FUNCTIONS_TENSOR_LIKE[torch_function] = func | |
return func | |
return decorator | |
diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py | |
index 6dd14a1a45..0661f31de0 100644 | |
--- a/tools/autograd/gen_python_functions.py | |
+++ b/tools/autograd/gen_python_functions.py | |
@@ -103,7 +103,14 @@ static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs) | |
if(r.has_torch_function()){ | |
PyObject* torch_function = PyObject_FastGetAttrString(r.get_overloaded_arg(0), "__torch_function__"); | |
- return PyObject_CallFunctionObjArgs(torch_function, PyUnicode_FromString(r.get_func_name().data()), args, kwargs, NULL); | |
+ PyObject* torch_module = PyImport_ImportModule("torch"); | |
+ if (!torch_module) return NULL; | |
+ PyObject* torch_api_function = PyObject_GetAttrString(torch_module, r.get_func_name().data()); | |
+ Py_DECREF(torch_module); | |
+ if (!torch_api_function) return NULL; | |
+ PyObject* result = PyObject_CallFunctionObjArgs(torch_function, torch_api_function, args, kwargs, NULL); | |
+ Py_DECREF(torch_api_function); | |
+ return result; | |
} | |
${declare_namedtuple_return_types} | |
${dispatch} | |
diff --git a/tools/autograd/templates/python_torch_functions.cpp b/tools/autograd/templates/python_torch_functions.cpp | |
index 5b1af1a5c4..1ea2aabe29 100644 | |
--- a/tools/autograd/templates/python_torch_functions.cpp | |
+++ b/tools/autograd/templates/python_torch_functions.cpp | |
@@ -417,6 +417,13 @@ static PyObject * THPVariable_nonzero(PyObject* self, PyObject* args, PyObject* | |
auto r = parser.parse(args, kwargs, parsed_args); | |
if(r.has_torch_function()){ | |
PyObject* torch_function = PyObject_FastGetAttrString(r.get_overloaded_arg(0), "__torch_function__"); | |
- return PyObject_CallFunctionObjArgs(torch_function, PyUnicode_FromString(r.get_func_name().data()), args, kwargs, NULL); | |
+ PyObject* torch_module = PyImport_ImportModule("torch"); | |
+ if (!torch_module) return NULL; | |
+ PyObject* torch_api_function = PyObject_GetAttrString(torch_module, r.get_func_name().data()); | |
+ Py_DECREF(torch_module); | |
+ if (!torch_api_function) return NULL; | |
+ PyObject* result = PyObject_CallFunctionObjArgs(torch_function, torch_api_function, args, kwargs, NULL); | |
+ Py_DECREF(torch_api_function); | |
+ return result; | |
} | |
if (r.idx == 0) { | |
diff --git a/torch/_overrides.py b/torch/_overrides.py | |
index 17239247f0..2c218b1e96 100644 | |
--- a/torch/_overrides.py | |
+++ b/torch/_overrides.py | |
@@ -127,7 +127,7 @@ def _implement_torch_function( | |
# Use `public_api` instead of `implemenation` so __torch_function__ | |
# implementations can do equality/identity comparisons. | |
result = overloaded_arg.__torch_function__( | |
- public_api.__name__, args, kwargs) | |
+ public_api, args, kwargs) | |
if result is not NotImplemented: | |
return result |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment