// webgpu.h second userdata / C++ API proof of concept
// GitHub gist kainino0x/4f04609d91111b360cb4be93822b2f68 (last active April 26, 2024 21:46).
// GitHub notes that this file contains bidirectional Unicode text that may be interpreted or
// compiled differently than what appears; review it in an editor that reveals hidden
// Unicode characters.
// Compile with -std=c++17 or later (std::align_val_t is a C++17 feature).
// Clang warning flags: -Weverything -Wno-unused-variable -Wno-unused-parameter -Wno-c++98-compat -Wno-c++98-compat-pedantic
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <functional>
#include <memory>
#include <new>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include <x86intrin.h>
// C API | |
extern "C" {

// Opaque handle type of the C API.
typedef struct WGPUSomethingImpl* WGPUSomething;

// Callback carrying a single untyped userdata pointer.
typedef void (*WGPUCallback)(WGPUSomething something, void* userdata);

// Callback carrying two untyped userdata pointers.
typedef void (*WGPUCallback2)(WGPUSomething something, void* userdata1, void* userdata2);

// Callback receiving a variable-sized userdata blob as (size, pointer).
typedef void (*WGPUCallbackV)(WGPUSomething something, size_t userdataSize, void* userdataPtr);

// Entry points that invoke the matching callback type with `something`.
void wgpuCallCallback(WGPUSomething something, WGPUCallback callback, void* userdata);
void wgpuCallCallback2(WGPUSomething something, WGPUCallback2 callback,
                       void* userdata1, void* userdata2);
void wgpuCallCallbackV(WGPUSomething something, WGPUCallbackV callback,
                       size_t userdataSize, void* userdataPtr, size_t userdataAlignment);

}
// C implementation | |
// Reference C implementation: calls `callback` immediately, forwarding `something`
// and the caller's `userdata` unchanged.
void wgpuCallCallback(WGPUSomething something, WGPUCallback callback, void* userdata) {
    (*callback)(something, userdata);
}
// Reference C implementation: calls `callback` immediately, forwarding `something`
// and both userdata pointers unchanged and in the same order.
void wgpuCallCallback2(WGPUSomething something, WGPUCallback2 callback, void* userdata1, void* userdata2) {
    (*callback)(something, userdata1, userdata2);
}
// Reference C implementation of the by-value ("V") userdata entry point.
// Copies the `userdataSize` bytes at `userdataPtr` into a private buffer aligned to
// max(alignof(max_align_t), userdataAlignment), invokes `callback` with that copy, and
// frees the copy when `callback` returns (or throws).
// In a real implementation the copy would be stored alongside the callback pointer and
// the callback invoked later.
// Preconditions: `userdataAlignment` is 0 or a power of two; `userdataPtr` points to at
// least `userdataSize` readable bytes (it may be null when `userdataSize` is 0).
void wgpuCallCallbackV(WGPUSomething something, WGPUCallbackV callback, size_t userdataSize, void* userdataPtr, size_t userdataAlignment) {
    const size_t alignment = std::max(alignof(std::max_align_t), userdataAlignment);
    // The aligned allocation functions require a power-of-two alignment.
    assert((alignment & (alignment - 1)) == 0);
    const std::align_val_t alignVal{alignment};

    // Memory from the aligned operator new must be released by the matching aligned
    // operator delete; a default unique_ptr<char[]> deleter would call plain delete[] (UB).
    auto freeAligned = [alignVal](unsigned char* p) { ::operator delete(p, alignVal); };
    std::unique_ptr<unsigned char, decltype(freeAligned)> userdataCopy{
        static_cast<unsigned char*>(::operator new(userdataSize, alignVal)), freeAligned};

    // memcpy with a null source is undefined even for zero bytes.
    if (userdataSize != 0) {
        std::memcpy(userdataCopy.get(), userdataPtr, userdataSize);
    }
    callback(something, userdataSize, userdataCopy.get());
}
// C++ API | |
namespace wgpu { | |
using Something = struct SomethingImpl*; | |
using PlainCallback = void (Something something, void* userdata); | |
using PlainCallbackV = void (Something something, size_t userdataSize, void* userdata); | |
using LambdaCallback = void (Something something); | |
template<typename T> | |
using TemplatedCallback = void (Something something, T* userdata); | |
void CallCallback0(Something, WGPUCallback callback, void* userdata); | |
void CallCallback(Something, PlainCallback* callback, void* userdata); | |
void CallCallback2(Something, PlainCallback* callback, void* userdata); | |
void CallCallbackV(Something something, PlainCallbackV* callback, size_t userdataSize, void* userdataPtr); | |
void CallCallback_L(Something, std::function<LambdaCallback> lambda); | |
template<typename T> | |
void CallCallback_LT(Something something, T lambda); | |
template<typename T> | |
void CallCallback2_T(Something something, TemplatedCallback<T>* callback, T* userdata); | |
void CallCallbackV_L(Something, std::function<LambdaCallback> lambda); | |
template<typename T> | |
void CallCallbackV_LT(Something something, T lambda); | |
} | |
// C++ implementation | |
namespace wgpu { | |
namespace { | |
struct CallbackInfo { | |
PlainCallback* callback; | |
void* userdata; | |
}; | |
void wrapCallback(WGPUSomething something, void* v_info) { | |
auto* info = reinterpret_cast<CallbackInfo*>(v_info); | |
info->callback(reinterpret_cast<Something>(something), info->userdata); | |
delete info; | |
} | |
void wrapCallback2(WGPUSomething something, void* v_callback, void* v_userdata) { | |
auto* callback = reinterpret_cast<PlainCallback*>(v_callback); | |
(*callback)(reinterpret_cast<Something>(something), v_userdata); | |
} | |
struct CallbackInfoV { | |
PlainCallbackV* callback; | |
size_t userdataSize; | |
void* userdata; | |
}; | |
void wrapCallbackV(WGPUSomething something, size_t, void* v_info) { | |
auto* info = reinterpret_cast<CallbackInfoV*>(v_info); | |
info->callback(reinterpret_cast<Something>(something), info->userdataSize, info->userdata); | |
delete info; | |
} | |
void wrapCallback_L(WGPUSomething something, void *v_lambda) { | |
auto* lambda = reinterpret_cast<std::function<LambdaCallback>*>(v_lambda); | |
(*lambda)(reinterpret_cast<Something>(something)); | |
delete lambda; | |
} | |
void wrapCallbackV_L(WGPUSomething something, size_t, void *v_lambda) { | |
auto* lambda = reinterpret_cast<std::function<LambdaCallback>*>(v_lambda); | |
(*lambda)(reinterpret_cast<Something>(something)); | |
} | |
} | |
namespace detail { | |
template<typename T> | |
void wrapCallback_LT(WGPUSomething something, void* v_lambda) { | |
auto* lambda = reinterpret_cast<T*>(v_lambda); | |
(*lambda)(reinterpret_cast<Something>(something)); | |
delete lambda; // <- EDIT: forgot this before | |
} | |
template<typename T> | |
void wrapCallback2_T(WGPUSomething something, void* v_callback, void* v_userdata) { | |
auto* callback = reinterpret_cast<TemplatedCallback<T>*>(v_callback); | |
(*callback)(reinterpret_cast<Something>(something), static_cast<T*>(v_userdata)); | |
} | |
template<typename T> | |
void wrapCallbackV_LT(WGPUSomething something, size_t userdataSize, void* v_lambda) { | |
auto* lambda = reinterpret_cast<T*>(v_lambda); | |
(*lambda)(reinterpret_cast<Something>(something)); | |
} | |
} | |
void CallCallback0(Something something, WGPUCallback callback, void* userdata) { | |
wgpuCallCallback(reinterpret_cast<WGPUSomething>(something), callback, userdata); | |
} | |
void CallCallback(Something something, PlainCallback* callback, void* userdata) { | |
auto* info = new CallbackInfo{callback, userdata}; | |
wgpuCallCallback(reinterpret_cast<WGPUSomething>(something), wrapCallback, info); | |
} | |
void CallCallback2(Something something, PlainCallback* callback, void* userdata) { | |
wgpuCallCallback2(reinterpret_cast<WGPUSomething>(something), wrapCallback2, reinterpret_cast<void*>(callback), userdata); | |
} | |
void CallCallbackV(Something something, PlainCallbackV* callback, size_t userdataSize, void* userdataPtr) { | |
wgpuCallCallbackV(reinterpret_cast<WGPUSomething>(something), wrapCallbackV, userdataSize, userdataPtr, alignof(max_align_t)); | |
} | |
void CallCallback_L(Something something, std::function<LambdaCallback> lambda) { | |
auto* l = new std::function<LambdaCallback>(lambda); | |
wgpuCallCallback(reinterpret_cast<WGPUSomething>(something), wrapCallback_L, l); | |
} | |
template<typename T> | |
void CallCallback_LT(Something something, T lambda) { | |
auto* l = new T(lambda); | |
wgpuCallCallback(reinterpret_cast<WGPUSomething>(something), &detail::wrapCallback_LT<T>, reinterpret_cast<void*>(l)); | |
} | |
template<typename T> | |
void CallCallback2_T(Something something, TemplatedCallback<T>* callback, T* userdata) { | |
wgpuCallCallback2(reinterpret_cast<WGPUSomething>(something), &detail::wrapCallback2_T<T>, reinterpret_cast<void*>(callback), static_cast<void*>(userdata)); | |
} | |
void CallCallbackV_L(Something something, std::function<LambdaCallback> lambda) { | |
// TODO: this is probably totally invalid if not is_trivially_copyable | |
static_assert(std::is_trivially_copyable<decltype(lambda)>::value, "lambda type must be trivially copyable"); | |
wgpuCallCallbackV(reinterpret_cast<WGPUSomething>(something), wrapCallbackV_L, sizeof(lambda), reinterpret_cast<void*>(&lambda), alignof(decltype(lambda))); | |
} | |
template<typename T> | |
void CallCallbackV_LT(Something something, T lambda) { | |
// TODO: this is probably totally invalid if not is_trivially_copyable | |
static_assert(std::is_trivially_copyable<T>::value, "lambda type must be trivially copyable"); | |
wgpuCallCallbackV(reinterpret_cast<WGPUSomething>(something), &detail::wrapCallbackV_LT<T>, sizeof(lambda), reinterpret_cast<void*>(&lambda), alignof(T)); | |
} | |
} | |
int main() { | |
std::string captured = "captured"; | |
static_assert(!std::is_trivially_copyable<decltype(captured)>::value, "test that we can captured non-trivially-copyable types"); | |
wgpu::Something something = nullptr; | |
// Plain callback API, but 'something' uses the C type (inconsistent) | |
wgpu::CallCallback0(something, [](WGPUSomething s, void* v_userdata) { | |
printf("hello %d: %p %s\n", __LINE__, static_cast<void*>(s), reinterpret_cast<std::string*>(v_userdata)->c_str()); | |
}, &captured); | |
// Plain callback API, with 'something' using the C++ type (consistent) | |
wgpu::CallCallback(something, [](wgpu::Something s, void* v_userdata) { | |
printf("hello %d: %p %s\n", __LINE__, static_cast<void*>(s), reinterpret_cast<std::string*>(v_userdata)->c_str()); | |
}, &captured); | |
// Plain callback API implemented without new/delete | |
wgpu::CallCallback2(something, [](wgpu::Something s, void* v_userdata) { | |
printf("hello %d: %p %s\n", __LINE__, static_cast<void*>(s), reinterpret_cast<std::string*>(v_userdata)->c_str()); | |
}, &captured); | |
// Lambda API, using a `new std::function` | |
wgpu::CallCallback_L(something, [&](wgpu::Something s) { | |
printf("hello %d: %p %s\n", __LINE__, static_cast<void*>(s), captured.c_str()); | |
}); | |
// Lambda API, using a template to allocate only once (`new T`, no allocation inside std::function) | |
wgpu::CallCallback_LT(something, [&](wgpu::Something s) { | |
printf("hello %d: %p %s\n", __LINE__, static_cast<void*>(s), captured.c_str()); | |
}); | |
// (I don't think a lambda API is possible to implement without new/delete) | |
// Plain callback API, without new/delete, with strongly-typed userdata | |
wgpu::CallCallback2_T<std::string>(something, [](wgpu::Something s, std::string* arg) { | |
printf("hello %d: %p %s\n", __LINE__, static_cast<void*>(s), arg->c_str()); | |
}, &captured); | |
// Lambda API, with a variable-sized userdata and std::function | |
wgpu::CallCallbackV_L(something, [&](wgpu::Something s) { | |
printf("hello %d: %p %s\n", __LINE__, static_cast<void*>(s), captured.c_str()); | |
}); | |
// Lambda API, with a variable-sized userdata and templating to avoid std::function | |
// - Capture by reference | |
wgpu::CallCallbackV_LT(something, [&](wgpu::Something s) { | |
printf("hello %d: %p %s\n", __LINE__, static_cast<void*>(s), captured.c_str()); | |
}); | |
// - Capture by value | |
wgpu::CallCallbackV_LT(something, [=](wgpu::Something s) { | |
printf("hello %d: %p %s\n", __LINE__, static_cast<void*>(s), captured.c_str()); | |
}); | |
// Test with lambda with an alignment greater than max_align_t | |
__m256d capturedSIMD = {3, 4, 5, 6}; | |
auto lambdaSIMD = [=](wgpu::Something s) { | |
size_t actualAlignmentOffset = reinterpret_cast<uintptr_t>(&capturedSIMD) % alignof(decltype(capturedSIMD)); | |
printf("hello %d: %p [%f, ...] (alignment offset %zu should be 0)\n", __LINE__, | |
static_cast<void*>(s), capturedSIMD[0], actualAlignmentOffset); | |
assert(actualAlignmentOffset == 0); | |
}; | |
static_assert(alignof(decltype(lambdaSIMD)) > alignof(max_align_t)); | |
} |
// (GitHub gist page footer: sign up for free, or sign in, to join this conversation and comment.)