Skip to content

Instantly share code, notes, and snippets.

@ngoldbaum
Created March 10, 2020 16:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ngoldbaum/b44f7711e61d3adb3f801e62354f3767 to your computer and use it in GitHub Desktop.
Save ngoldbaum/b44f7711e61d3adb3f801e62354f3767 to your computer and use it in GitHub Desktop.
// _cat
//
// Python-callable wrapper for `_cat`. Parses (args, kwargs) against the
// signature string below and forwards to at::_cat, or to at::_cat_out when
// the keyword-only `out` argument is supplied. The resulting Tensor is
// converted with wrap() and returned to the caller.
static PyObject * THPVariable__cat(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "_cat(TensorList tensors, int64_t dim=0, *, Tensor out=None)",
  }, /*traceable=*/true);

  ParsedArgs<3> arg_slots;
  auto _r = parser.parse(args, kwargs, arg_slots);

  // Hand the call over to the __torch_function__ handler when the parser
  // reports that one applies.
  if (_r.has_torch_function()) {
    return handle_torch_function(_r, args, kwargs, THPVariableFunctionsModule, "torch");
  }

  // Arguments 0 and 1 are used by both overloads; argument 2 is `out`.
  const auto input_tensors = _r.tensorlist(0);
  const int64_t cat_dim = _r.toInt64(1);

  if (!_r.isNone(2)) {
    // aten::_cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
    auto run_cat_out = [](Tensor out, TensorList tensors, int64_t dim) -> Tensor {
      // The GIL is released only for the duration of the ATen call; the
      // guard goes out of scope when the lambda returns, before wrap() runs.
      pybind11::gil_scoped_release no_gil;
      return at::_cat_out(out, tensors, dim);
    };
    return wrap(run_cat_out(_r.tensor(2), input_tensors, cat_dim));
  } else {
    // aten::_cat(Tensor[] tensors, int dim=0) -> Tensor
    auto run_cat = [](TensorList tensors, int64_t dim) -> Tensor {
      pybind11::gil_scoped_release no_gil;
      return at::_cat(tensors, dim);
    };
    return wrap(run_cat(input_tensors, cat_dim));
  }

  // Not reached: both branches above return. Retained from the original.
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment