Skip to content

Instantly share code, notes, and snippets.

@xmfan
Created May 13, 2024 23:47
Show Gist options
  • Save xmfan/fd80d3b51f0021ccd292b62a6114f18d to your computer and use it in GitHub Desktop.
#include "/tmp/tmp8wdgz8ol/rq/crq573iugmokkndxawm743sgoqnmhemtfiwhap5ducjuyma5rxco.h"
// Broadcast multiply: scales each of the 5 elements of in_ptr1 by the
// single scalar held in in_ptr0[0], writing results into out_ptr0.
// Generated for a fixed problem size of 5 elements.
extern "C" void kernel(const float* in_ptr0,
const float* in_ptr1,
float* out_ptr0)
{
{
// Hint the compiler to vectorize with 8-wide SIMD lanes.
#pragma omp simd simdlen(8)
for (long idx = 0L; idx < 5L; ++idx)
{
const float scale = in_ptr0[0L];
const float value = in_ptr1[idx];
out_ptr0[idx] = scale * value;
}
}
}
// Python bindings to call kernel():
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <sstream>
#include <cstdlib>
#ifndef _MSC_VER
#if __cplusplus < 202002L
// C++20 earlier code
// https://en.cppreference.com/w/cpp/language/attributes/likely
// Fallback branch-prediction hint macros for pre-C++20 builds; the C++20
// [[likely]]/[[unlikely]] attributes are also used directly below, and these
// function-like macros do not interfere with the attribute spelling.
#define likely(x) __builtin_expect(!!(x), 1)
#define unlikely(x) __builtin_expect(!!(x), 0)
#endif
#endif
// This is defined in guards.cpp so we don't need to import PyTorch headers that are slooow.
// We manually link it below to workaround issues with fbcode build.
// Hook that maps a Python tensor object to its raw data pointer. Its address
// is installed at import time in PyInit_kernel (read from an environment
// variable), so it is null until the module has been initialized.
static void* (*_torchinductor_pyobject_tensor_data_ptr)(PyObject* obj);
// Extract argument n from the args tuple as a raw data pointer of type T,
// routing the PyObject through the externally provided tensor-data-ptr hook.
// Only pointer types are accepted here; `long` has its own specialization.
template <typename T>
static inline T parse_arg(PyObject* args, size_t n) {
static_assert(std::is_pointer<T>::value, "arg type must be pointer or long");
PyObject* const item = PyTuple_GET_ITEM(args, n);
void* const data = _torchinductor_pyobject_tensor_data_ptr(item);
return static_cast<T>(data);
}
// Specialization for integer arguments: converts tuple element n to a
// Py_ssize_t-sized long. A result of -1 with a pending Python error means
// the conversion failed, which is surfaced as a C++ exception (translated
// back to a Python RuntimeError by the caller's catch block).
template <>
inline long parse_arg<long>(PyObject* args, size_t n) {
const auto value = PyLong_AsSsize_t(PyTuple_GET_ITEM(args, n));
if (value == -1 && PyErr_Occurred()) {
[[unlikely]] throw std::runtime_error("expected int arg");
}
return value;
}
// Python-callable wrapper for kernel(): validates that exactly three
// positional arguments were passed as a tuple, unpacks each tensor's data
// pointer, runs the kernel, and returns None. Any C++ exception is mapped
// onto a Python RuntimeError so it never escapes into the interpreter.
static PyObject* kernel_py(PyObject* self, PyObject* args) {
try {
if (!PyTuple_CheckExact(args)) {
[[unlikely]] throw std::runtime_error("tuple args required");
}
if (PyTuple_GET_SIZE(args) != 3) {
[[unlikely]] throw std::runtime_error("requires 3 args");
}
auto* in0 = parse_arg<float*>(args, 0);
auto* in1 = parse_arg<float*>(args, 1);
auto* out0 = parse_arg<float*>(args, 2);
kernel(in0, in1, out0);
Py_RETURN_NONE;
} catch (std::exception const& e) {
PyErr_SetString(PyExc_RuntimeError, e.what());
return nullptr;
} catch (...) {
PyErr_SetString(PyExc_RuntimeError, "unhandled error");
return nullptr;
}
}
// Method table: exposes kernel_py to Python as kernel.kernel(in0, in1, out).
static PyMethodDef py_methods[] = {
{"kernel", kernel_py, METH_VARARGS, ""},
{NULL, NULL, 0, NULL}};
// Module definition for the extension module named "kernel"; -1 means the
// module keeps state in static globals rather than per-module storage.
static struct PyModuleDef py_module =
{PyModuleDef_HEAD_INIT, "kernel", NULL, -1, py_methods};
// Module init: reads the address of the tensor-data-ptr helper from the
// _TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR environment variable (set by the
// TorchInductor host process) and installs it into the static hook before
// creating the module. Returns nullptr with a Python error set on failure.
PyMODINIT_FUNC PyInit_kernel(void) {
const char* str_addr = std::getenv("_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR");
if(!str_addr) {
PyErr_SetString(PyExc_RuntimeError, "_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR must be set");
return nullptr;
}
std::istringstream iss(str_addr);
uintptr_t addr = 0;
iss >> addr;
// Reject unparseable or zero addresses: previously a malformed value left
// addr == 0, installing a null function pointer that would crash the
// interpreter on the first kernel() call instead of failing at import.
if(iss.fail() || addr == 0) {
PyErr_SetString(PyExc_RuntimeError, "_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR must be a valid nonzero address");
return nullptr;
}
_torchinductor_pyobject_tensor_data_ptr =
reinterpret_cast<decltype(_torchinductor_pyobject_tensor_data_ptr)>(addr);
return PyModule_Create(&py_module);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment