Skip to content

Instantly share code, notes, and snippets.

@sun1638650145
Last active June 27, 2024 04:20
Show Gist options
  • Save sun1638650145/8685a94cce248aef72d5b3d03b7b40bc to your computer and use it in GitHub Desktop.
Save sun1638650145/8685a94cce248aef72d5b3d03b7b40bc to your computer and use it in GitHub Desktop.
/*
[Description of PR](https://github.com/pybind/pybind11/pull/3544)
It is not yet known whether the PR will be merged,
so if you want to use this feature, you can use this patch directly.
The patch corresponds to pybind11 2.9.x ~ 2.13.1.
To avoid redefinition of `struct`:
`npy_api` -> `npy_api_patch`
`is_complex` -> `is_complex_patch`
`npy_format_descriptor_name` -> `npy_format_descriptor_name_patch`
Copyright (c) 2021-2024 Steve Sun <s1638650145@gmail.com>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#pragma once
#include "pybind11/pybind11.h"
#include "pybind11/complex.h"
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
// Forward declaration: the `detail::type_caster<numpy_scalar<T>>` specialisation below names
// `numpy_scalar<T>` before its full definition near the end of this file.
template <typename>
struct numpy_scalar;
PYBIND11_NAMESPACE_BEGIN(detail)
// Imports `<numpy core package>.<submodule_name>`, where the core package is `numpy.core` for
// NumPy 1.x and `numpy._core` for NumPy >= 2 (it was renamed when it became private).
// The major version is read via `numpy.lib.NumpyVersion(numpy.__version__).major`.
// With PYBIND11_NUMPY_1_ONLY defined, a NumPy >= 2 runtime raises std::runtime_error.
// NOTE(review): pybind11's own numpy.h (2.12+) appears to define a function with this name in
// `detail` as well; including both headers in one translation unit may clash — confirm.
PYBIND11_NOINLINE module_ import_numpy_core_submodule(const char *submodule_name) {
    module_ numpy_module = module_::import("numpy");
    str numpy_version_str = numpy_module.attr("__version__");
    object parsed_version = module_::import("numpy.lib").attr("NumpyVersion")(numpy_version_str);
    const int numpy_major = parsed_version.attr("major").cast<int>();
#ifdef PYBIND11_NUMPY_1_ONLY
    if (numpy_major >= 2) {
        throw std::runtime_error(
            "This extension was built with PYBIND11_NUMPY_1_ONLY defined, "
            "but NumPy 2 is used in this process. For NumPy2 compatibility, "
            "this extension needs to be rebuilt without the PYBIND11_NUMPY_1_ONLY define.");
    }
#endif
    // Build e.g. "numpy._core.multiarray" piece by piece.
    std::string qualified_name = (numpy_major >= 2) ? "numpy._core" : "numpy.core";
    qualified_name += '.';
    qualified_name += submodule_name;
    return module_::import(qualified_name.c_str());
}
// Base case of `platform_lookup`: no candidate type matched the requested size.
template <std::size_t>
constexpr int platform_lookup() {
    return -1;
}
// Lookup a type according to its size, and return a value corresponding to the NumPy typenum.
// Candidate types `T, Ts...` pair positionally with the values `I, Is...`; the value of the
// first candidate whose `sizeof` equals `size` (in bytes) is returned, or -1 if none matches.
template <std::size_t size, typename T, typename... Ts, typename... Ints>
constexpr int platform_lookup(int I, Ints... Is) {
    // Compare against the requested byte count itself: `sizeof(size)` would be the size of
    // `std::size_t`, a constant that ignores the requested size entirely.
    return size == sizeof(T) ? I : platform_lookup<size, Ts...>(Is...);
}
struct npy_api_patch {
enum constants {
NPY_ARRAY_C_CONTIGUOUS_ = 0x0001,
NPY_ARRAY_F_CONTIGUOUS_ = 0x0002,
NPY_ARRAY_OWNDATA_ = 0x0004,
NPY_ARRAY_FORCECAST_ = 0x0010,
NPY_ARRAY_ENSUREARRAY_ = 0x0040,
NPY_ARRAY_ALIGNED_ = 0x0100,
NPY_ARRAY_WRITEABLE_ = 0x0400,
NPY_BOOL_ = 0,
NPY_BYTE_,
NPY_UBYTE_,
NPY_SHORT_,
NPY_USHORT_,
NPY_INT_,
NPY_UINT_,
NPY_LONG_,
NPY_ULONG_,
NPY_LONGLONG_,
NPY_ULONGLONG_,
NPY_FLOAT_,
NPY_DOUBLE_,
NPY_LONGDOUBLE_,
NPY_CFLOAT_,
NPY_CDOUBLE_,
NPY_CLONGDOUBLE_,
NPY_OBJECT_ = 17,
NPY_STRING_,
NPY_UNICODE_,
NPY_VOID_,
// Platform-dependent normalization
NPY_INT8_ = NPY_BYTE_,
NPY_UINT8_ = NPY_UBYTE_,
NPY_INT16_ = NPY_SHORT_,
NPY_UINT16_ = NPY_USHORT_,
// `npy_common.h` defines the integer aliases. In order, it checks:
// NPY_BITSOF_LONG, NPY_BITSOF_LONGLONG, NPY_BITSOF_INT, NPY_BITSOF_SHORT, NPY_BITSOF_CHAR
// and assigns the alias to the first matching size, so we should check in this order.
NPY_INT32_ = platform_lookup<4, long, int, short>(NPY_LONG_, NPY_INT_, NPY_SHORT_),
NPY_UINT32_ = platform_lookup<4, unsigned long, unsigned int, unsigned short>(
NPY_ULONG_, NPY_UINT_, NPY_USHORT_),
NPY_INT64_ = platform_lookup<8, long, long long, int>(NPY_LONG_, NPY_LONGLONG_, NPY_INT_),
NPY_UINT64_ = platform_lookup<8, unsigned long, unsigned long long, unsigned int>(
NPY_ULONG_, NPY_ULONGLONG_, NPY_UINT_),
NPY_FLOAT32_
= platform_lookup<4, double, float, long double>(NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
NPY_FLOAT64_
= platform_lookup<8, double, float, long double>(NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
NPY_COMPLEX64_
= platform_lookup<8, std::complex<double>, std::complex<float>, std::complex<long double>>(
NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
NPY_COMPLEX128_
= platform_lookup<8, std::complex<double>, std::complex<float>, std::complex<long double>>(
NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
NPY_CHAR_ = std::is_signed<char>::value ? NPY_BYTE_ : NPY_UBYTE_,
};
unsigned int PyArray_RUNTIME_VERSION_;
struct PyArray_Dims {
Py_intptr_t *ptr;
int len;
};
static npy_api_patch &get() {
PYBIND11_CONSTINIT static gil_safe_call_once_and_store<npy_api_patch> storage;
return storage.call_once_and_store_result(lookup).get_stored();
}
bool PyArray_Check_(PyObject *obj) const {
return PyObject_TypeCheck(obj, PyArray_Type_) != 0;
}
bool PyArrayDescr_Check_(PyObject *obj) const {
return PyObject_TypeCheck(obj, PyArrayDescr_Type_) != 0;
}
unsigned int (*PyArray_GetNDArrayCFeatureVersion_)();
PyObject *(*PyArray_DescrFromType_)(int);
PyObject *(*PyArray_TypeObjectFromType_)(int);
PyObject *(*PyArray_NewFromDescr_)(PyTypeObject *,
PyObject *,
int,
Py_intptr_t const *,
Py_intptr_t const *,
void *,
int,
PyObject *);
// Unused. Not removed because that affects ABI of the class.
PyObject *(*PyArray_DescrNewFromType_)(int);
int (*PyArray_CopyInto_)(PyObject *, PyObject *);
PyObject *(*PyArray_NewCopy_)(PyObject *, int);
PyTypeObject *PyArray_Type_;
PyTypeObject *PyVoidArrType_Type_;
PyTypeObject *PyArrayDescr_Type_;
PyObject *(*PyArray_DescrFromScalar_)(PyObject *);
PyObject *(*PyArray_Scalar_)(void *, PyObject *, PyObject *);
void (*PyArray_ScalarAsCtype_)(PyObject *, void *);
PyObject *(*PyArray_FromAny_)(PyObject *, PyObject *, int, int, int, PyObject *);
int (*PyArray_DescrConverter_)(PyObject *, PyObject **);
bool (*PyArray_EquivTypes_)(PyObject *, PyObject *);
#ifdef PYBIND11_NUMPY_1_ONLY
int (*PyArray_GetArrayParamsFromObject_)(PyObject *,
PyObject *,
unsigned char,
PyObject **,
int *,
Py_intptr_t *,
PyObject **,
PyObject *);
#endif
PyObject *(*PyArray_Squeeze_)(PyObject *);
// Unused. Not removed because that affects ABI of the class.
int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
PyObject *(*PyArray_Resize_)(PyObject *, PyArray_Dims *, int, int);
PyObject *(*PyArray_Newshape_)(PyObject *, PyArray_Dims *, int);
PyObject *(*PyArray_View_)(PyObject *, PyObject *, PyObject *);
private:
enum functions {
API_PyArray_GetNDArrayCFeatureVersion = 211,
API_PyArray_Type = 2,
API_PyArrayDescr_Type = 3,
API_PyVoidArrType_Type = 39,
API_PyArray_DescrFromType = 45,
API_PyArray_TypeObjectFromType = 46,
API_PyArray_DescrFromScalar = 57,
API_PyArray_Scalar = 60,
API_PyArray_ScalarAsCtype = 62,
API_PyArray_FromAny = 69,
API_PyArray_Resize = 80,
// CopyInto was slot 82 and 50 was effectively an alias. NumPy 2 removed 82.
API_PyArray_CopyInto = 50,
API_PyArray_NewCopy = 85,
API_PyArray_NewFromDescr = 94,
API_PyArray_DescrNewFromType = 96,
API_PyArray_Newshape = 135,
API_PyArray_Squeeze = 136,
API_PyArray_View = 137,
API_PyArray_DescrConverter = 174,
API_PyArray_EquivTypes = 182,
#ifdef PYBIND11_NUMPY_1_ONLY
API_PyArray_GetArrayParamsFromObject = 278,
#endif
API_PyArray_SetBaseObject = 282
};
static npy_api_patch lookup() {
module_ m = detail::import_numpy_core_submodule("multiarray");
auto c = m.attr("_ARRAY_API");
void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), nullptr);
if (api_ptr == nullptr) {
raise_from(PyExc_SystemError, "FAILURE obtaining numpy _ARRAY_API pointer.");
throw error_already_set();
}
npy_api_patch api;
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
DECL_NPY_API(PyArray_GetNDArrayCFeatureVersion);
api.PyArray_RUNTIME_VERSION_ = api.PyArray_GetNDArrayCFeatureVersion_();
if (api.PyArray_RUNTIME_VERSION_ < 0x7) {
pybind11_fail("pybind11 numpy support requires numpy >= 1.7.0");
}
DECL_NPY_API(PyArray_Type);
DECL_NPY_API(PyVoidArrType_Type);
DECL_NPY_API(PyArrayDescr_Type);
DECL_NPY_API(PyArray_DescrFromType);
DECL_NPY_API(PyArray_TypeObjectFromType);
DECL_NPY_API(PyArray_DescrFromScalar);
DECL_NPY_API(PyArray_Scalar);
DECL_NPY_API(PyArray_ScalarAsCtype);
DECL_NPY_API(PyArray_FromAny);
DECL_NPY_API(PyArray_Resize);
DECL_NPY_API(PyArray_CopyInto);
DECL_NPY_API(PyArray_NewCopy);
DECL_NPY_API(PyArray_NewFromDescr);
DECL_NPY_API(PyArray_DescrNewFromType);
DECL_NPY_API(PyArray_Newshape);
DECL_NPY_API(PyArray_Squeeze);
DECL_NPY_API(PyArray_View);
DECL_NPY_API(PyArray_DescrConverter);
DECL_NPY_API(PyArray_EquivTypes);
#ifdef PYBIND11_NUMPY_1_ONLY
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
#endif
DECL_NPY_API(PyArray_SetBaseObject);
#undef DECL_NPY_API
return api;
}
};
// Trait that is true exactly for `std::complex<...>` specialisations; selects the complex
// overload of `npy_format_descriptor_name_patch`. Renamed from pybind11's `is_complex` so the
// two do not clash.
template <typename T>
struct is_complex_patch : std::integral_constant<bool, false> {};
template <typename Component>
struct is_complex_patch<std::complex<Component>> : std::integral_constant<bool, true> {};
// Compile-time NumPy dtype name used in generated signatures. Only the integral,
// floating-point and complex specialisations below are defined.
template <typename T, typename = void>
struct npy_format_descriptor_name_patch;

// Integral types: "int<bits>" / "uint<bits>" by signedness, except `bool` -> "bool".
template <typename T>
struct npy_format_descriptor_name_patch<T, enable_if_t<std::is_integral<T>::value>> {
    static constexpr auto name = const_name<!std::is_same<T, bool>::value>(
        const_name<std::is_signed<T>::value>("int", "uint") + const_name<sizeof(T) * 8>(),
        const_name("bool"));
};

// Floating point: "float<bits>" for float/double, "longdouble" for anything else.
template <typename T>
struct npy_format_descriptor_name_patch<T, enable_if_t<std::is_floating_point<T>::value>> {
    static constexpr auto name
        = const_name<(std::is_same<T, double>::value || std::is_same<T, float>::value)>(
            const_name("float") + const_name<sizeof(T) * 8>(), const_name("longdouble"));
};

// Complex: "complex<2*bits>" for complex<float>/complex<double>, "longcomplex" otherwise.
template <typename T>
struct npy_format_descriptor_name_patch<T, enable_if_t<is_complex_patch<T>::value>> {
    static constexpr auto name
        = const_name<(std::is_same<typename T::value_type, double>::value
                      || std::is_same<typename T::value_type, float>::value)>(
            const_name("complex") + const_name<sizeof(typename T::value_type) * 16>(),
            const_name("longcomplex"));
};
// Per-C++-type metadata for `numpy_scalar<T>`: the signature `name` and the NumPy `typenum`
// used to look up the matching scalar type object and dtype. Only the types declared below
// are supported; the primary template is intentionally empty.
template <typename T>
struct numpy_scalar_info {};

// Declares `numpy_scalar_info<ctype_>` with an arbitrary constant expression as typenum.
#define DECL_NPY_SCALAR_TYPENUM(ctype_, typenum_expr_)                                          \
    template <>                                                                               \
    struct numpy_scalar_info<ctype_> {                                                        \
        static constexpr auto name = npy_format_descriptor_name_patch<ctype_>::name;          \
        static constexpr int typenum = (typenum_expr_);                                       \
    }
// Shorthand: `DECL_NPY_SCALAR(float, NPY_FLOAT)` uses `npy_api_patch::NPY_FLOAT_`.
#define DECL_NPY_SCALAR(ctype_, typenum_)                                                       \
    DECL_NPY_SCALAR_TYPENUM(ctype_, npy_api_patch::typenum_##_)

// boolean type
DECL_NPY_SCALAR(bool, NPY_BOOL);
// character types
DECL_NPY_SCALAR(char, NPY_CHAR);
DECL_NPY_SCALAR(signed char, NPY_BYTE);
DECL_NPY_SCALAR(unsigned char, NPY_UBYTE);
// signed integer types
DECL_NPY_SCALAR(std::int16_t, NPY_SHORT);
DECL_NPY_SCALAR(std::int32_t, NPY_INT);
// `std::int64_t` is a typedef for `long` or `long long` depending on the platform, so the two
// fundamental types are declared instead of `std::int64_t` itself: this covers `std::int64_t`
// everywhere without duplicate specialisations. A 64-bit `long long` uses NumPy's int64
// typenum, which is NPY_LONG when `long` is also 64-bit and NPY_LONGLONG otherwise (e.g. LLP64
// Windows, where mapping it to the 32-bit NPY_LONG would read/write the wrong number of bytes).
DECL_NPY_SCALAR(long, NPY_LONG);
DECL_NPY_SCALAR_TYPENUM(long long,
                        sizeof(long) == sizeof(long long) ? npy_api_patch::NPY_LONG_
                                                          : npy_api_patch::NPY_LONGLONG_);
// unsigned integer types (same reasoning as the signed ones above)
DECL_NPY_SCALAR(std::uint16_t, NPY_USHORT);
DECL_NPY_SCALAR(std::uint32_t, NPY_UINT);
DECL_NPY_SCALAR(unsigned long, NPY_ULONG);
DECL_NPY_SCALAR_TYPENUM(unsigned long long,
                        sizeof(unsigned long) == sizeof(unsigned long long)
                            ? npy_api_patch::NPY_ULONG_
                            : npy_api_patch::NPY_ULONGLONG_);
// floating point types
DECL_NPY_SCALAR(float, NPY_FLOAT);
DECL_NPY_SCALAR(double, NPY_DOUBLE);
DECL_NPY_SCALAR(long double, NPY_LONGDOUBLE);
// complex types
DECL_NPY_SCALAR(std::complex<float>, NPY_CFLOAT);
DECL_NPY_SCALAR(std::complex<double>, NPY_CDOUBLE);
DECL_NPY_SCALAR(std::complex<long double>, NPY_CLONGDOUBLE);
#undef DECL_NPY_SCALAR
#undef DECL_NPY_SCALAR_TYPENUM
// Caster between NumPy scalar objects (e.g. `numpy.float32`) and `numpy_scalar<T>`.
// `load` accepts only objects passing `isinstance(src, <NumPy scalar type for T>)`; the
// `convert` flag is ignored, so plain Python ints/floats are not accepted.
template <typename T>
struct type_caster<numpy_scalar<T>> {
    using value_type = T;
    using type_info = numpy_scalar_info<T>;
    PYBIND11_TYPE_CASTER(numpy_scalar<T>, type_info::name);

    // NumPy scalar type object for `T`, resolved once and cached (never released).
    static handle &target_type() {
        // Resolve the API table *before* entering the function-local static initializer:
        // `npy_api_patch::get()` goes through `gil_safe_call_once_and_store`, which may release
        // the GIL on first use. Doing that while holding this static's init guard can deadlock
        // against another thread that holds the GIL and waits on the same guard.
        auto &api = npy_api_patch::get();
        static handle tp = api.PyArray_TypeObjectFromType_(type_info::typenum);
        return tp;
    }

    // NumPy dtype (descriptor) for `T`, resolved once and cached (never released).
    static handle &target_dtype() {
        // See target_type(): keep the possibly GIL-releasing call outside the static guard.
        auto &api = npy_api_patch::get();
        static handle tp = api.PyArray_DescrFromType_(type_info::typenum);
        return tp;
    }

    // Copies the C value out of a matching NumPy scalar; returns false for anything else.
    bool load(handle src, bool) {
        if (isinstance(src, target_type())) {
            npy_api_patch::get().PyArray_ScalarAsCtype_(src.ptr(), &value.value);
            return true;
        }
        return false;
    }

    // Creates a new NumPy scalar of the dtype for `T` holding `src.value`.
    static handle cast(numpy_scalar<T> src, return_value_policy, handle) {
        return npy_api_patch::get().PyArray_Scalar_(&src.value, target_dtype().ptr(), nullptr);
    }
};
PYBIND11_NAMESPACE_END(detail)
// Value wrapper marking a C++ scalar that is exchanged with Python as a NumPy scalar
// (via `detail::type_caster<numpy_scalar<T>>`) rather than as a builtin Python number.
template <typename T>
struct numpy_scalar {
    using value_type = T;
    // Wrapped value; value-initialised (zero for arithmetic types) when default-constructed,
    // instead of being left indeterminate.
    value_type value{};

    numpy_scalar() = default;
    // Implicit on purpose (as in the original API): a plain `T` converts to `numpy_scalar<T>`.
    numpy_scalar(value_type v) : value(v) {}

    // Converts back to the wrapped value; `const` so it also works on const objects.
    operator value_type() const { return value; }

    numpy_scalar &operator=(value_type v) {
        value = v;
        return *this;
    }
};

// Convenience factory: wraps `value` in a `numpy_scalar<T>` with `T` deduced from the argument.
template <typename T>
numpy_scalar<T> make_scalar(T value) {
    return numpy_scalar<T>(value);
}
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
@sun1638650145
Copy link
Author

Update pybind11 2.10.3.

@sun1638650145
Copy link
Author

Update pybind11 2.10.4.

@sun1638650145
Copy link
Author

Update pybind11 2.11.

@sun1638650145
Copy link
Author

Update pybind11 2.11.1.

@sun1638650145
Copy link
Author

Update pybind11 2.12.0.

@sun1638650145
Copy link
Author

Update pybind11 2.13.1.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment