Skip to content

Instantly share code, notes, and snippets.

@nullhook
Last active February 16, 2024 16:43
Show Gist options
  • Save nullhook/11d74c02dc42e061ade9528973fae7f4 to your computer and use it in GitHub Desktop.
Save nullhook/11d74c02dc42e061ade9528973fae7f4 to your computer and use it in GitHub Desktop.
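Compute in Metal (metal-cpp): two independent C++ programs — the first runs kernel `E_` from the default Metal library; the second compiles `E_` from source at runtime while recording a GPU trace.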
#include <cstdio>
#include <cstring>
#include <iostream>
#define NS_PRIVATE_IMPLEMENTATION
#define CA_PRIVATE_IMPLEMENTATION
#define MTL_PRIVATE_IMPLEMENTATION
#include "Metal.hpp"
MTL::Buffer* outputs;
MTL::Buffer* input0;
int main() {
// both represents a metal context
MTL::Device* device = MTL::CreateSystemDefaultDevice();
MTL::CommandQueue* command_queue = device->newCommandQueue();
MTL::Library* library = device->newDefaultLibrary();
if (!library) assert(false);
MTL::Function* E_ = library->newFunction(NS::String::string("E_", NS::StringEncoding::UTF8StringEncoding));
NS::Error* error = nullptr;
MTL::ComputePipelineState* pso = device->newComputePipelineState(E_, &error);
if (!pso) {
std::cerr << error->localizedDescription()->utf8String() << "\n";
assert(false);
}
// all buffers are effectively shared on apple silicon
// due to unified memory architecture you can get away with
// forgetting to call didModifyRange.
// API contract says you must call didModifyRange, but the driver doesn’t enforce that on apple silicon
outputs = device->newBuffer(4, MTL::ResourceStorageModeManaged);
input0 = device->newBuffer(4, MTL::ResourceStorageModeManaged);
const float a{10.0};
memcpy(input0->contents(), &a, sizeof(float));
// input0->didModifyRange(NS::Range::Make( 0, sizeof(float)));
MTL::CommandBuffer* cmd_buff = command_queue->commandBuffer();
MTL::ComputeCommandEncoder* cmd_enc = cmd_buff->computeCommandEncoder();
// pass buffers to compute shaders
cmd_enc->setComputePipelineState(pso);
cmd_enc->setBuffer(outputs, 0, 0);
cmd_enc->setBuffer(input0, 0, 1);
cmd_enc->dispatchThreadgroups(MTL::Size({1, 1, 1}), MTL::Size({1, 1, 1}));
cmd_enc->endEncoding();
// cmd_buff->addCompletedHandler([](const MTL::CommandBuffer* ignored) {
// float* out = static_cast<float*>(outputs->contents());
// if (out != nullptr) {
// printf("%.2f\n", *out);
// }
// });
cmd_buff->commit();
// figure out runloop so we can read output buf
cmd_buff->waitUntilCompleted();
float* out = static_cast<float*>(outputs->contents());
if (out != nullptr) {
printf("%.2f\n", *out);
}
}
// gputrace via capture manager
// metal binary archive
#include <vector>
#include <iostream>
#include "sys/mman.h"
#include "sys/stat.h"
#define NS_PRIVATE_IMPLEMENTATION
#define CA_PRIVATE_IMPLEMENTATION
#define MTL_PRIVATE_IMPLEMENTATION
#include "lib/Metal.hpp"
#define max(x,y) ((x>y)?x:y)
#define int64 long
#define half __fp16
#define uchar unsigned char
#define bool uchar
#define _NS_PRIVATE_SEL(accessor) (Private::Selector::s_k##accessor)
int main() {
// both represents a metal context
MTL::Device* device = MTL::CreateSystemDefaultDevice();
MTL::CommandQueue* command_queue = device->newCommandQueue();
// capture
bool success;
MTL::CaptureManager* captureManager = MTL::CaptureManager::sharedCaptureManager();
success = captureManager->supportsDestination( MTL::CaptureDestinationGPUTraceDocument );
if (!success) {
__builtin_printf( "Capture support is not enabled\n");
assert( false );
}
MTL::CaptureDescriptor* pCaptureDescriptor = MTL::CaptureDescriptor::alloc()->init();
pCaptureDescriptor->setDestination( MTL::CaptureDestinationGPUTraceDocument );
pCaptureDescriptor->setOutputURL( NS::URL::fileURLWithPath(NS::String::string("/tmp/compute.gputrace", NS::StringEncoding::ASCIIStringEncoding)) );
pCaptureDescriptor->setCaptureObject(device);
NS::Error *pError = nullptr;
success = captureManager->startCapture( pCaptureDescriptor, &pError );
std::cout << device->name()->utf8String() << "\n";
const char* shaderSrc = R"(
#include <metal_stdlib>
using namespace metal;
kernel void E_(device unsigned char* data0, const device char* data1, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
auto val0 = (long)(*(data1+0));
auto val1 = (long)(*(data1+1));
auto val2 = (long)(*(data1+2));
auto val3 = (long)(*(data1+3));
*(data0+0) = static_cast<uchar>(val0);
*(data0+1) = static_cast<uchar>(val1);
*(data0+2) = static_cast<uchar>(val2);
*(data0+3) = static_cast<uchar>(val3);
}
)";
// const char* shaderSrc1 = R"(
// #include <metal_stdlib>
// using namespace metal;
// kernel void E_(device float* data0, const device float* data1, const device float* data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
// float val0 = *(data1+0);
// float val1 = *(data2+0);
// *(data0+0) = (val0+val1);
// }
// )";
// HOW TO LOAD METAL BINARY ARCHIVE?
NS::Error* libError = nullptr;
MTL::CompileOptions* options = MTL::CompileOptions::alloc()->init();
MTL::Library* library = device->newLibrary(NS::String::string(shaderSrc, NS::StringEncoding::UTF8StringEncoding), options, &libError);
if (!library) assert(false);
// NS::Error* e = nullptr;
// dispatch_data_t dispatch_data = dispatch_data_create(xlibrary, sizeof(*xlibrary), dispatch_get_main_queue(), NULL);
// MTL::Library* library = device->newLibrary(dispatch_data, &e);
// device->newLibrary()
// // BinaryArchive
// NS::Error* archive_err = nullptr;
// MTL::BinaryArchiveDescriptor* archive_desc = MTL::BinaryArchiveDescriptor::alloc()->init();
// MTL::BinaryArchive* archive = device->newBinaryArchive(archive_desc, &archive_err);
// MTL::ComputePipelineDescriptor* compute_desc = MTL::ComputePipelineDescriptor::alloc()->init();
// NS::Error* c_err = nullptr;
// compute_desc->setComputeFunction(library->newFunction(NS::String::string("E_", NS::StringEncoding::ASCIIStringEncoding)));
// archive->addComputePipelineFunctions(compute_desc, &c_err);
// NS::Error* url_Error = nullptr;
// auto url = NS::URL::alloc()->initFileURLWithPath(NS::String::string("/tmp/tmpyslkfs", NS::StringEncoding::UTF8StringEncoding));
// bool is_success = archive->serializeToURL(url, &url_Error);
// if(!is_success) {
// std::cerr << "serializeToURL" << "\n";
// assert(false);
// }
MTL::Function* E_ = library->newFunction(NS::String::string("E_", NS::StringEncoding::ASCIIStringEncoding));
if (!E_) {
assert(false);
}
MTL::ComputePipelineDescriptor* compute_desc = MTL::ComputePipelineDescriptor::alloc()->init();
compute_desc->setComputeFunction(E_);
NS::Error* error = nullptr;
MTL::ComputePipelineState* pso = device->newComputePipelineState(E_, &error);
if (!pso) {
std::cerr << error->localizedDescription()->utf8String() << "\n";
assert(false);
}
MTL::CommandBuffer* cmd_buff = command_queue->commandBuffer();
MTL::ComputeCommandEncoder* cmd_enc = cmd_buff->computeCommandEncoder();
MTL::Buffer* outputs = device->newBuffer(4, MTL::ResourceStorageModeManaged);
MTL::Buffer* input0 = device->newBuffer(4, MTL::ResourceStorageModeManaged);
std::vector<char> a{-1, -2, -3, -4};
memcpy(input0->contents(), a.data(), a.size() * sizeof(char));
input0->didModifyRange(NS::Range::Make( 0, a.size() * sizeof(char)));
// // pass buffers to compute shaders
cmd_enc->setComputePipelineState(pso);
cmd_enc->setBuffer(outputs, 0, 0);
cmd_enc->setBuffer(input0, 0, 1);
cmd_enc->dispatchThreadgroups(MTL::Size({1, 1, 1}), MTL::Size({1, 1, 1}));
cmd_enc->endEncoding();
// MTL::BlitCommandEncoder* bc_enc = cmd_buff->blitCommandEncoder();
// bc_enc->synchronizeResource(outputs);
// bc_enc->endEncoding();
cmd_buff->commit();
cmd_buff->waitUntilCompleted();
// // captureManager->stopCapture();
auto* out = static_cast<unsigned char*>(outputs->contents());
if (out != nullptr) {
for (int i=0; i<4; ++i) {
printf("%u, ", out[i]);
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment