Skip to content

Instantly share code, notes, and snippets.

@kainino0x
Last active April 26, 2024 21:46
Show Gist options
  • Save kainino0x/4f04609d91111b360cb4be93822b2f68 to your computer and use it in GitHub Desktop.
Save kainino0x/4f04609d91111b360cb4be93822b2f68 to your computer and use it in GitHub Desktop.
webgpu.h second userdata / C++ API proof of concept
// -std=c++17 (C++17 is required, e.g. for std::align_val_t)
// -Weverything -Wno-unused-variable -Wno-unused-parameter -Wno-c++98-compat -Wno-c++98-compat-pedantic
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <functional>
#include <memory>
#include <new>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include <x86intrin.h>
// C API
extern "C" {
// Opaque handle type for the object handed to every callback.
typedef struct WGPUSomethingImpl* WGPUSomething;
// Callback carrying a single opaque userdata pointer.
typedef void(*WGPUCallback)(WGPUSomething, void* userdata);
// Callback carrying two opaque userdata pointers. The C++ layer in this file
// uses userdata1 for its own trampoline state so it does not need to allocate.
typedef void(*WGPUCallback2)(WGPUSomething, void* userdata1, void* userdata2);
// Callback receiving userdata "by value": userdataPtr points at the
// implementation's own copy of userdataSize bytes.
typedef void(*WGPUCallbackV)(WGPUSomething, size_t userdataSize, void* userdataPtr);
// Invokes `callback` with `userdata` (the implementation in this file calls it
// synchronously).
void wgpuCallCallback(WGPUSomething, WGPUCallback callback, void* userdata);
// Invokes `callback` with both userdata pointers, passed through unchanged.
void wgpuCallCallback2(WGPUSomething, WGPUCallback2 callback, void* userdata1, void* userdata2);
// Copies userdataSize bytes from userdataPtr into storage aligned to at least
// max(alignof(max_align_t), userdataAlignment) and invokes the callback with
// that copy.
void wgpuCallCallbackV(WGPUSomething, WGPUCallbackV, size_t userdataSize, void* userdataPtr, size_t userdataAlignment);
}
// C implementation
// Forwards the object and the caller's single userdata pointer to `callback`,
// synchronously and unchanged.
void wgpuCallCallback(WGPUSomething something, WGPUCallback callback, void* userdata) {
    WGPUCallback const target = callback;
    (*target)(something, userdata);
}
// Forwards the object and both userdata pointers to `callback`, synchronously
// and in their original order.
void wgpuCallCallback2(WGPUSomething something, WGPUCallback2 callback, void* userdata1, void* userdata2) {
    WGPUCallback2 const target = callback;
    (*target)(something, userdata1, userdata2);
}
// Guarantees the userdataPtr seen by the callback is aligned to
// max(alignof(max_align_t), userdataAlignment); userdataAlignment must be a
// power of two (or 0).
void wgpuCallCallbackV(WGPUSomething something, WGPUCallbackV callback, size_t userdataSize, void* userdataPtr, size_t userdataAlignment) {
    // in a real implementation here we would copy the userdata somewhere along with the callback pointer
    const size_t alignmentValue = std::max(alignof(max_align_t), userdataAlignment);
    assert((alignmentValue & (alignmentValue - 1)) == 0 && "userdataAlignment must be a power of two");
    const std::align_val_t alignment{alignmentValue};

    // Memory from the alignment-aware operator new must be released by the
    // matching alignment-aware operator delete. unique_ptr<char[]>'s default
    // delete[] would call the unaligned overload, which is undefined behavior.
    struct AlignedDelete {
        std::align_val_t alignment;
        void operator()(void* p) const noexcept { ::operator delete(p, alignment); }
    };
    std::unique_ptr<void, AlignedDelete> userdataCopy{
        ::operator new(userdataSize, alignment), AlignedDelete{alignment}};

    // memcpy with a null source is undefined even for zero bytes.
    if (userdataSize != 0) {
        memcpy(userdataCopy.get(), userdataPtr, userdataSize);
    }
    callback(something, userdataSize, userdataCopy.get());
}
// C++ API
namespace wgpu {
// C++ handle type; converted to/from WGPUSomething with reinterpret_cast.
using Something = struct SomethingImpl*;
// Plain function-pointer callback with one opaque userdata pointer.
using PlainCallback = void (Something something, void* userdata);
// Callback receiving a by-value userdata blob (size + pointer to a copy).
using PlainCallbackV = void (Something something, size_t userdataSize, void* userdata);
// Signature for capturing lambdas; state lives in the closure, not in userdata.
using LambdaCallback = void (Something something);
// Plain callback whose userdata is strongly typed instead of void*.
template<typename T>
using TemplatedCallback = void (Something something, T* userdata);
// Takes the C callback type directly (inconsistent with the C++ handle type).
void CallCallback0(Something, WGPUCallback callback, void* userdata);
// C++-typed plain callback routed through the single-userdata C entry point.
void CallCallback(Something, PlainCallback* callback, void* userdata);
// Same contract as CallCallback, routed through the two-userdata C entry point.
void CallCallback2(Something, PlainCallback* callback, void* userdata);
// Plain callback whose userdata bytes are copied by value through the C layer.
void CallCallbackV(Something something, PlainCallbackV* callback, size_t userdataSize, void* userdataPtr);
// Type-erased lambda API using std::function.
void CallCallback_L(Something, std::function<LambdaCallback> lambda);
// Lambda API templated on the closure type, avoiding std::function.
template<typename T>
void CallCallback_LT(Something something, T lambda);
// Strongly-typed userdata variant of CallCallback2.
template<typename T>
void CallCallback2_T(Something something, TemplatedCallback<T>* callback, T* userdata);
// std::function lambda API routed through the by-value (V) C entry point.
void CallCallbackV_L(Something, std::function<LambdaCallback> lambda);
// Templated lambda API routed through the by-value (V) C entry point.
template<typename T>
void CallCallbackV_LT(Something something, T lambda);
}
// C++ implementation
namespace wgpu {
namespace {
struct CallbackInfo {
PlainCallback* callback;
void* userdata;
};
void wrapCallback(WGPUSomething something, void* v_info) {
auto* info = reinterpret_cast<CallbackInfo*>(v_info);
info->callback(reinterpret_cast<Something>(something), info->userdata);
delete info;
}
void wrapCallback2(WGPUSomething something, void* v_callback, void* v_userdata) {
auto* callback = reinterpret_cast<PlainCallback*>(v_callback);
(*callback)(reinterpret_cast<Something>(something), v_userdata);
}
struct CallbackInfoV {
PlainCallbackV* callback;
size_t userdataSize;
void* userdata;
};
void wrapCallbackV(WGPUSomething something, size_t, void* v_info) {
auto* info = reinterpret_cast<CallbackInfoV*>(v_info);
info->callback(reinterpret_cast<Something>(something), info->userdataSize, info->userdata);
delete info;
}
void wrapCallback_L(WGPUSomething something, void *v_lambda) {
auto* lambda = reinterpret_cast<std::function<LambdaCallback>*>(v_lambda);
(*lambda)(reinterpret_cast<Something>(something));
delete lambda;
}
void wrapCallbackV_L(WGPUSomething something, size_t, void *v_lambda) {
auto* lambda = reinterpret_cast<std::function<LambdaCallback>*>(v_lambda);
(*lambda)(reinterpret_cast<Something>(something));
}
}
namespace detail {
template<typename T>
void wrapCallback_LT(WGPUSomething something, void* v_lambda) {
auto* lambda = reinterpret_cast<T*>(v_lambda);
(*lambda)(reinterpret_cast<Something>(something));
delete lambda; // <- EDIT: forgot this before
}
template<typename T>
void wrapCallback2_T(WGPUSomething something, void* v_callback, void* v_userdata) {
auto* callback = reinterpret_cast<TemplatedCallback<T>*>(v_callback);
(*callback)(reinterpret_cast<Something>(something), static_cast<T*>(v_userdata));
}
template<typename T>
void wrapCallbackV_LT(WGPUSomething something, size_t userdataSize, void* v_lambda) {
auto* lambda = reinterpret_cast<T*>(v_lambda);
(*lambda)(reinterpret_cast<Something>(something));
}
}
void CallCallback0(Something something, WGPUCallback callback, void* userdata) {
wgpuCallCallback(reinterpret_cast<WGPUSomething>(something), callback, userdata);
}
void CallCallback(Something something, PlainCallback* callback, void* userdata) {
auto* info = new CallbackInfo{callback, userdata};
wgpuCallCallback(reinterpret_cast<WGPUSomething>(something), wrapCallback, info);
}
void CallCallback2(Something something, PlainCallback* callback, void* userdata) {
wgpuCallCallback2(reinterpret_cast<WGPUSomething>(something), wrapCallback2, reinterpret_cast<void*>(callback), userdata);
}
void CallCallbackV(Something something, PlainCallbackV* callback, size_t userdataSize, void* userdataPtr) {
wgpuCallCallbackV(reinterpret_cast<WGPUSomething>(something), wrapCallbackV, userdataSize, userdataPtr, alignof(max_align_t));
}
void CallCallback_L(Something something, std::function<LambdaCallback> lambda) {
auto* l = new std::function<LambdaCallback>(lambda);
wgpuCallCallback(reinterpret_cast<WGPUSomething>(something), wrapCallback_L, l);
}
template<typename T>
void CallCallback_LT(Something something, T lambda) {
auto* l = new T(lambda);
wgpuCallCallback(reinterpret_cast<WGPUSomething>(something), &detail::wrapCallback_LT<T>, reinterpret_cast<void*>(l));
}
template<typename T>
void CallCallback2_T(Something something, TemplatedCallback<T>* callback, T* userdata) {
wgpuCallCallback2(reinterpret_cast<WGPUSomething>(something), &detail::wrapCallback2_T<T>, reinterpret_cast<void*>(callback), static_cast<void*>(userdata));
}
void CallCallbackV_L(Something something, std::function<LambdaCallback> lambda) {
// TODO: this is probably totally invalid if not is_trivially_copyable
static_assert(std::is_trivially_copyable<decltype(lambda)>::value, "lambda type must be trivially copyable");
wgpuCallCallbackV(reinterpret_cast<WGPUSomething>(something), wrapCallbackV_L, sizeof(lambda), reinterpret_cast<void*>(&lambda), alignof(decltype(lambda)));
}
template<typename T>
void CallCallbackV_LT(Something something, T lambda) {
// TODO: this is probably totally invalid if not is_trivially_copyable
static_assert(std::is_trivially_copyable<T>::value, "lambda type must be trivially copyable");
wgpuCallCallbackV(reinterpret_cast<WGPUSomething>(something), &detail::wrapCallbackV_LT<T>, sizeof(lambda), reinterpret_cast<void*>(&lambda), alignof(T));
}
}
// Exercises every callback flavor once; each callback prints its source line.
int main() {
    std::string captured = "captured";
    static_assert(!std::is_trivially_copyable<decltype(captured)>::value, "test that we can captured non-trivially-copyable types");
    wgpu::Something something = nullptr;
    // Plain callback API, but 'something' uses the C type (inconsistent)
    wgpu::CallCallback0(something, [](WGPUSomething s, void* v_userdata) {
        printf("hello %d: %p %s\n", __LINE__, static_cast<void*>(s), reinterpret_cast<std::string*>(v_userdata)->c_str());
    }, &captured);
    // Plain callback API, with 'something' using the C++ type (consistent)
    wgpu::CallCallback(something, [](wgpu::Something s, void* v_userdata) {
        printf("hello %d: %p %s\n", __LINE__, static_cast<void*>(s), reinterpret_cast<std::string*>(v_userdata)->c_str());
    }, &captured);
    // Plain callback API implemented without new/delete
    wgpu::CallCallback2(something, [](wgpu::Something s, void* v_userdata) {
        printf("hello %d: %p %s\n", __LINE__, static_cast<void*>(s), reinterpret_cast<std::string*>(v_userdata)->c_str());
    }, &captured);
    // Lambda API, using a `new std::function`
    wgpu::CallCallback_L(something, [&](wgpu::Something s) {
        printf("hello %d: %p %s\n", __LINE__, static_cast<void*>(s), captured.c_str());
    });
    // Lambda API, using a template to allocate only once (`new T`, no allocation inside std::function)
    wgpu::CallCallback_LT(something, [&](wgpu::Something s) {
        printf("hello %d: %p %s\n", __LINE__, static_cast<void*>(s), captured.c_str());
    });
    // (I don't think a lambda API is possible to implement without new/delete)
    // Plain callback API, without new/delete, with strongly-typed userdata
    wgpu::CallCallback2_T<std::string>(something, [](wgpu::Something s, std::string* arg) {
        printf("hello %d: %p %s\n", __LINE__, static_cast<void*>(s), arg->c_str());
    }, &captured);
    // Lambda API, with a variable-sized userdata and std::function
    wgpu::CallCallbackV_L(something, [&](wgpu::Something s) {
        printf("hello %d: %p %s\n", __LINE__, static_cast<void*>(s), captured.c_str());
    });
    // Lambda API, with a variable-sized userdata and templating to avoid std::function
    // - Capture by reference
    wgpu::CallCallbackV_LT(something, [&](wgpu::Something s) {
        printf("hello %d: %p %s\n", __LINE__, static_cast<void*>(s), captured.c_str());
    });
    // - Capture by value
    wgpu::CallCallbackV_LT(something, [=](wgpu::Something s) {
        printf("hello %d: %p %s\n", __LINE__, static_cast<void*>(s), captured.c_str());
    });
    // Test with lambda with an alignment greater than max_align_t
    __m256d capturedSIMD = {3, 4, 5, 6};
    auto lambdaSIMD = [=](wgpu::Something s) {
        size_t actualAlignmentOffset = reinterpret_cast<uintptr_t>(&capturedSIMD) % alignof(decltype(capturedSIMD));
        printf("hello %d: %p [%f, ...] (alignment offset %zu should be 0)\n", __LINE__,
               static_cast<void*>(s), capturedSIMD[0], actualAlignmentOffset);
        assert(actualAlignmentOffset == 0);
    };
    static_assert(alignof(decltype(lambdaSIMD)) > alignof(max_align_t), "lambdaSIMD must be over-aligned for this test to mean anything");
    // Route it through the by-value API so the over-aligned copy is actually
    // checked at runtime (previously the lambda was built but never invoked).
    wgpu::CallCallbackV_LT(something, lambdaSIMD);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment