Created March 10, 2020 16:24
-
-
Save ngoldbaum/b44f7711e61d3adb3f801e62354f3767 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// _cat
//
// Python entry point for `_cat`, parsed against the single accepted signature
//   _cat(TensorList tensors, int64_t dim=0, *, Tensor out=None)
// Parsed argument slots: 0 = tensors, 1 = dim, 2 = out.
//
// Behavior:
//   - If the parser reports a torch-function override (_r.has_torch_function()),
//     the call is forwarded to handle_torch_function(..., "torch") and ATen is
//     not called.
//   - If `out` is None: returns wrap(at::_cat(tensors, dim)).
//   - Otherwise: returns wrap(at::_cat_out(out, tensors, dim)).
// The GIL is released (pybind11::gil_scoped_release) only for the duration of
// the ATen call inside each dispatch lambda; argument parsing and wrapping run
// with the GIL held.
// The body is wrapped in HANDLE_TH_ERRORS / END_HANDLE_TH_ERRORS; presumably
// those translate C++ exceptions into a Python error and a nullptr return —
// NOTE(review): confirm against the macro definitions.
static PyObject * THPVariable__cat(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  // Function-local static: the parser is constructed once, on first call.
  static PythonArgParser parser({
    "_cat(TensorList tensors, int64_t dim=0, *, Tensor out=None)",
  }, /*traceable=*/true);

  ParsedArgs<3> parsed_args;
  auto _r = parser.parse(args, kwargs, parsed_args);
  if (_r.has_torch_function()) {
    return handle_torch_function(_r, args, kwargs, THPVariableFunctionsModule, "torch");
  }

  if (_r.isNone(2)) {
    // aten::_cat(Tensor[] tensors, int dim=0) -> Tensor
    auto dispatch__cat = [](TensorList tensors, int64_t dim) -> Tensor {
      pybind11::gil_scoped_release no_gil;
      return at::_cat(tensors, dim);
    };
    return wrap(dispatch__cat(_r.tensorlist(0), _r.toInt64(1)));
  } else {
    // aten::_cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
    auto dispatch__cat_out = [](Tensor out, TensorList tensors, int64_t dim) -> Tensor {
      pybind11::gil_scoped_release no_gil;
      return at::_cat_out(out, tensors, dim);
    };
    return wrap(dispatch__cat_out(_r.tensor(2), _r.tensorlist(0), _r.toInt64(1)));
  }
  // Unreachable: both branches above return. Kept as a defensive fallback so
  // every textual path of the try-block visibly yields a PyObject*.
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.