Skip to content

Instantly share code, notes, and snippets.

@Kiterai
Created December 14, 2023 14:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Kiterai/a0c440aed2f58397f002d11afe1bdaf5 to your computer and use it in GitHub Desktop.
Save Kiterai/a0c440aed2f58397f002d11afe1bdaf5 to your computer and use it in GitHub Desktop.
#include <filesystem>
#include <fstream>
#include <iostream>
#include <random>
#include <vulkan/vulkan.hpp>
/// Push-constant block handed to the compute shader: the dimensions of C = A * B.
///
/// The host pushes exactly sizeof(Param) bytes at offset 0, so this struct must
/// stay byte-compatible with the shader's `layout(push_constant) uniform Param`
/// block (three consecutive 32-bit uints at offsets 0, 4 and 8).
struct Param {
    uint32_t matAH;    // rows of A (and of the result C)
    uint32_t matAWBH;  // columns of A, which must equal the rows of B
    uint32_t matBW;    // columns of B (and of the result C)
};
// Guard against a field being added or padded without updating the shader.
static_assert(sizeof(Param) == 3 * sizeof(uint32_t), "Param must match the shader push-constant layout");
/// Owns a Vulkan buffer together with the device memory that backs it.
///
/// Both handles are RAII wrappers. Members are destroyed in reverse declaration
/// order, so `memory` is released before `buffer`. Do not reorder the fields
/// without considering that teardown order.
struct Buffer
{
    vk::UniqueBuffer buffer;        // the buffer object itself
    vk::UniqueDeviceMemory memory;  // allocation bound to `buffer` (allocBuffer binds it at offset 0)
};
/// Creates a buffer, allocates memory for it and binds the two together.
///
/// The buffer is created with exclusive sharing. Memory comes from the first
/// memory type that is (a) allowed by the buffer's memory requirements and
/// (b) has every flag in `reqMemProp`. It is bound at offset 0.
///
/// @param device      logical device that will own the buffer and the memory
/// @param memProps    memory properties of the physical device behind `device`
/// @param bufSize     requested buffer size in bytes
/// @param usage       buffer usage flags
/// @param reqMemProp  property flags the chosen memory type must include
/// @return            a Buffer owning both handles
///
/// Prints a message and terminates the process with exit(-1) when no memory
/// type satisfies both conditions.
auto allocBuffer(vk::Device device, const vk::PhysicalDeviceMemoryProperties &memProps, vk::DeviceSize bufSize,
                 vk::BufferUsageFlags usage, vk::MemoryPropertyFlags reqMemProp) {
    vk::BufferCreateInfo bufferCreateInfo;
    bufferCreateInfo.size = bufSize;
    bufferCreateInfo.usage = usage;
    bufferCreateInfo.sharingMode = vk::SharingMode::eExclusive;

    Buffer buf;
    buf.buffer = device.createBufferUnique(bufferCreateInfo);
    const vk::MemoryRequirements memReq = device.getBufferMemoryRequirements(buf.buffer.get());

    vk::MemoryAllocateInfo memAllocInfo;
    // The driver may need more than bufSize (alignment/padding), so allocate what it reports.
    memAllocInfo.allocationSize = memReq.size;

    bool suitableMemoryTypeFound = false;
    for (uint32_t i = 0; i < memProps.memoryTypeCount; i++) {
        // Unsigned mask: with a plain `1 << i` the shift reaches the sign bit of int at i == 31.
        const bool allowedForBuffer = (memReq.memoryTypeBits & (1u << i)) != 0;
        const bool hasRequiredProps = (memProps.memoryTypes[i].propertyFlags & reqMemProp) == reqMemProp;
        if (allowedForBuffer && hasRequiredProps) {
            memAllocInfo.memoryTypeIndex = i;
            suitableMemoryTypeFound = true;
            break;
        }
    }
    if (!suitableMemoryTypeFound) {
        std::cerr << "適切なメモリタイプが存在しません。" << std::endl;
        exit(-1);
    }

    buf.memory = device.allocateMemoryUnique(memAllocInfo);
    device.bindBufferMemory(buf.buffer.get(), buf.memory.get(), 0);
    return buf;
}
int main() {
vk::InstanceCreateInfo createInfo;
vk::UniqueInstance instance = vk::createInstanceUnique(createInfo);
std::vector<vk::PhysicalDevice> physicalDevices = instance->enumeratePhysicalDevices();
vk::PhysicalDevice physicalDevice;
bool existsSuitablePhysicalDevice = false;
uint32_t computeQueueFamilyIndex;
for (size_t i = 0; i < physicalDevices.size(); i++) {
std::vector<vk::QueueFamilyProperties> queueProps = physicalDevices[i].getQueueFamilyProperties();
bool existsComputeQueue = false;
for (size_t j = 0; j < queueProps.size(); j++) {
if (queueProps[j].queueFlags & vk::QueueFlagBits::eCompute) {
existsComputeQueue = true;
computeQueueFamilyIndex = j;
break;
}
}
if (existsComputeQueue) {
physicalDevice = physicalDevices[i];
existsSuitablePhysicalDevice = true;
break;
}
}
if (!existsSuitablePhysicalDevice) {
std::cerr << "使用可能な物理デバイスがありません。" << std::endl;
return -1;
}
vk::DeviceCreateInfo devCreateInfo;
vk::DeviceQueueCreateInfo queueCreateInfo[1];
queueCreateInfo[0].queueFamilyIndex = computeQueueFamilyIndex;
queueCreateInfo[0].queueCount = 1;
float queuePriorities[1] = {1.0};
queueCreateInfo[0].pQueuePriorities = queuePriorities;
devCreateInfo.pQueueCreateInfos = queueCreateInfo;
devCreateInfo.queueCreateInfoCount = 1;
vk::UniqueDevice device = physicalDevice.createDeviceUnique(devCreateInfo);
vk::Queue computeQueue = device->getQueue(computeQueueFamilyIndex, 0);
vk::CommandPoolCreateInfo cmdPoolCreateInfo;
cmdPoolCreateInfo.queueFamilyIndex = computeQueueFamilyIndex;
vk::UniqueCommandPool cmdPool = device->createCommandPoolUnique(cmdPoolCreateInfo);
vk::CommandBufferAllocateInfo cmdBufAllocInfo;
cmdBufAllocInfo.commandPool = cmdPool.get();
cmdBufAllocInfo.commandBufferCount = 1;
cmdBufAllocInfo.level = vk::CommandBufferLevel::ePrimary;
std::vector<vk::UniqueCommandBuffer> cmdBufs =
device->allocateCommandBuffersUnique(cmdBufAllocInfo);
vk::PhysicalDeviceMemoryProperties memProps = physicalDevice.getMemoryProperties();
uint32_t matrixAHeight = 8;
uint32_t matrixAWidth = 8;
uint32_t matrixBHeight = 8;
uint32_t matrixBWidth = 8;
if(matrixAWidth != matrixBHeight) {
std::cerr << "不正な行列積" << std::endl;
return -1;
}
Param param;
param.matAH = matrixAHeight;
param.matAWBH = matrixAWidth;
param.matBW = matrixBWidth;
std::vector<float> matrixA(matrixAHeight * matrixAWidth);
std::vector<float> matrixB(matrixBHeight * matrixBWidth);
std::vector<float> matrixC(matrixAHeight * matrixBWidth);
auto matrixABuf = allocBuffer(device.get(), memProps, sizeof(float) * matrixAHeight * matrixAWidth, vk::BufferUsageFlagBits::eStorageBuffer, vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible);
auto matrixBBuf = allocBuffer(device.get(), memProps, sizeof(float) * matrixBHeight * matrixBWidth, vk::BufferUsageFlagBits::eStorageBuffer, vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible);
auto matrixCBuf = allocBuffer(device.get(), memProps, sizeof(float) * matrixAHeight * matrixBWidth, vk::BufferUsageFlagBits::eStorageBuffer, vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible);
std::random_device seed;
std::mt19937 engine(seed());
std::uniform_real_distribution<float> rand_dist(-1.0f, 1.0f);
std::cout << "matrixA:" << std::endl;
{
float *matrixAData = static_cast<float*>(device->mapMemory(matrixABuf.memory.get(), 0, sizeof(float) * matrixAHeight * matrixAWidth));
for(int i = 0; i < matrixAHeight; i++) {
for(int j = 0; j < matrixAWidth; j++) {
matrixAData[i * matrixAWidth + j] = rand_dist(engine);
std::cout << matrixAData[i * matrixAWidth + j] << ' ';
}
std::cout << std::endl;
}
device->unmapMemory(matrixABuf.memory.get());
}
std::cout << "matrixB:" << std::endl;
{
float *matrixBData = static_cast<float*>(device->mapMemory(matrixBBuf.memory.get(), 0, sizeof(float) * matrixAHeight * matrixAWidth));
for(int i = 0; i < matrixBHeight; i++) {
for(int j = 0; j < matrixBWidth; j++) {
matrixBData[i * matrixBWidth + j] = rand_dist(engine);
std::cout << matrixBData[i * matrixAWidth + j] << ' ';
}
std::cout << std::endl;
}
device->unmapMemory(matrixBBuf.memory.get());
}
vk::DescriptorSetLayoutBinding descSetLayoutBinding[3];
descSetLayoutBinding[0].binding = 0;
descSetLayoutBinding[0].descriptorType = vk::DescriptorType::eStorageBuffer;
descSetLayoutBinding[0].descriptorCount = 1;
descSetLayoutBinding[0].stageFlags = vk::ShaderStageFlagBits::eCompute;
descSetLayoutBinding[1].binding = 1;
descSetLayoutBinding[1].descriptorType = vk::DescriptorType::eStorageBuffer;
descSetLayoutBinding[1].descriptorCount = 1;
descSetLayoutBinding[1].stageFlags = vk::ShaderStageFlagBits::eCompute;
descSetLayoutBinding[2].binding = 2;
descSetLayoutBinding[2].descriptorType = vk::DescriptorType::eStorageBuffer;
descSetLayoutBinding[2].descriptorCount = 1;
descSetLayoutBinding[2].stageFlags = vk::ShaderStageFlagBits::eCompute;
vk::DescriptorSetLayoutCreateInfo descSetLayoutCreateInfo{};
descSetLayoutCreateInfo.bindingCount = std::size(descSetLayoutBinding);
descSetLayoutCreateInfo.pBindings = descSetLayoutBinding;
vk::UniqueDescriptorSetLayout descSetLayout = device->createDescriptorSetLayoutUnique(descSetLayoutCreateInfo);
vk::DescriptorPoolSize descPoolSize[1];
descPoolSize[0].type = vk::DescriptorType::eStorageBuffer;
descPoolSize[0].descriptorCount = 3;
vk::DescriptorPoolCreateInfo descPoolCreateInfo;
descPoolCreateInfo.poolSizeCount = std::size(descPoolSize);
descPoolCreateInfo.pPoolSizes = descPoolSize;
descPoolCreateInfo.maxSets = 1;
vk::UniqueDescriptorPool descPool = device->createDescriptorPoolUnique(descPoolCreateInfo);
vk::DescriptorSetAllocateInfo descSetAllocInfo;
auto descSetLayouts = {descSetLayout.get()};
descSetAllocInfo.descriptorPool = descPool.get();
descSetAllocInfo.descriptorSetCount = descSetLayouts.size();
descSetAllocInfo.pSetLayouts = descSetLayouts.begin();
std::vector<vk::UniqueDescriptorSet> descSets = device->allocateDescriptorSetsUnique(descSetAllocInfo);
vk::DescriptorBufferInfo descBufInfoA[1];
descBufInfoA[0].buffer = matrixABuf.buffer.get();
descBufInfoA[0].offset = 0;
descBufInfoA[0].range = sizeof(float) * matrixAHeight * matrixAWidth;
vk::DescriptorBufferInfo descBufInfoB[1];
descBufInfoB[0].buffer = matrixBBuf.buffer.get();
descBufInfoB[0].offset = 0;
descBufInfoB[0].range = sizeof(float) * matrixBHeight * matrixBWidth;
vk::DescriptorBufferInfo descBufInfoC[1];
descBufInfoC[0].buffer = matrixCBuf.buffer.get();
descBufInfoC[0].offset = 0;
descBufInfoC[0].range = sizeof(float) * matrixAHeight * matrixBWidth;
vk::WriteDescriptorSet writeDescSetA;
writeDescSetA.dstSet = descSets[0].get();
writeDescSetA.dstBinding = 0;
writeDescSetA.dstArrayElement = 0;
writeDescSetA.descriptorType = vk::DescriptorType::eStorageBuffer;
writeDescSetA.descriptorCount = 1;
writeDescSetA.pBufferInfo = descBufInfoA;
vk::WriteDescriptorSet writeDescSetB;
writeDescSetB.dstSet = descSets[0].get();
writeDescSetB.dstBinding = 1;
writeDescSetB.dstArrayElement = 0;
writeDescSetB.descriptorType = vk::DescriptorType::eStorageBuffer;
writeDescSetB.descriptorCount = 1;
writeDescSetB.pBufferInfo = descBufInfoB;
vk::WriteDescriptorSet writeDescSetC;
writeDescSetC.dstSet = descSets[0].get();
writeDescSetC.dstBinding = 2;
writeDescSetC.dstArrayElement = 0;
writeDescSetC.descriptorType = vk::DescriptorType::eStorageBuffer;
writeDescSetC.descriptorCount = 1;
writeDescSetC.pBufferInfo = descBufInfoC;
device->updateDescriptorSets({ writeDescSetA, writeDescSetB, writeDescSetC }, {});
auto pipelineDescSetLayouts = {descSetLayout.get()};
vk::PushConstantRange pushConstantRanges[1];
pushConstantRanges[0].offset = 0;
pushConstantRanges[0].size = sizeof(Param);
pushConstantRanges[0].stageFlags = vk::ShaderStageFlagBits::eCompute;
vk::PipelineLayoutCreateInfo layoutCreateInfo;
layoutCreateInfo.setLayoutCount = pipelineDescSetLayouts.size();
layoutCreateInfo.pSetLayouts = pipelineDescSetLayouts.begin();
layoutCreateInfo.pushConstantRangeCount = std::size(pushConstantRanges);
layoutCreateInfo.pPushConstantRanges = pushConstantRanges;
vk::UniquePipelineLayout pipelineLayout = device->createPipelineLayoutUnique(layoutCreateInfo);
size_t spvFileSz = std::filesystem::file_size("shader.comp.spv");
std::ifstream spvFile("shader.comp.spv", std::ios_base::binary);
std::vector<char> spvFileData(spvFileSz);
spvFile.read(spvFileData.data(), spvFileSz);
vk::ShaderModuleCreateInfo computeShaderCreateInfo;
computeShaderCreateInfo.codeSize = spvFileSz;
computeShaderCreateInfo.pCode = reinterpret_cast<const uint32_t *>(spvFileData.data());
vk::UniqueShaderModule computeShader = device->createShaderModuleUnique(computeShaderCreateInfo);
vk::ComputePipelineCreateInfo pipelineCreateInfo;
pipelineCreateInfo.layout = pipelineLayout.get();
pipelineCreateInfo.stage.stage = vk::ShaderStageFlagBits::eCompute;
pipelineCreateInfo.stage.module = computeShader.get();
pipelineCreateInfo.stage.pName = "main";
vk::UniquePipeline pipeline = device->createComputePipelineUnique(nullptr, pipelineCreateInfo).value;
vk::CommandBufferBeginInfo cmdBeginInfo;
cmdBufs[0]->begin(cmdBeginInfo);
cmdBufs[0]->bindPipeline(vk::PipelineBindPoint::eCompute, pipeline.get());
cmdBufs[0]->bindDescriptorSets(vk::PipelineBindPoint::eCompute, pipelineLayout.get(), 0, {descSets[0].get()}, {});
cmdBufs[0]->pushConstants(pipelineLayout.get(), vk::ShaderStageFlagBits::eCompute, 0, sizeof(Param), &param);
cmdBufs[0]->dispatch(matrixAHeight, matrixBWidth, 1);
cmdBufs[0]->end();
vk::CommandBuffer submitCmdBuf[1] = {cmdBufs[0].get()};
vk::SubmitInfo submitInfo;
submitInfo.commandBufferCount = 1;
submitInfo.pCommandBuffers = submitCmdBuf;
computeQueue.submit({submitInfo}, nullptr);
computeQueue.waitIdle();
std::cout << "matrixA x matrixB = " << std::endl;
float *calcData = static_cast<float*>(device->mapMemory(matrixCBuf.memory.get(), 0, sizeof(float) * matrixAHeight * matrixBWidth));
for(int i = 0; i < matrixAHeight; i++) {
for(int j = 0; j < matrixBWidth; j++) {
std::cout << calcData[i * matrixBWidth + j] << ' ';
}
std::cout << std::endl;
}
device->unmapMemory(matrixCBuf.memory.get());
computeQueue.waitIdle();
return 0;
}
#version 450
// Computes C = A * B, one output element per workgroup.
//   gl_WorkGroupID.x selects the output column (a column of B).
//   gl_WorkGroupID.y selects the output row (a row of A).
// The invocations of a workgroup split the dot product over k with a stride of
// gl_WorkGroupSize.x and combine their partial sums through shared memory.
// Any inner dimension (matAWBH) therefore works, and no subgroup extension is
// required. All matrices are row-major float arrays.
layout(push_constant) uniform Param {
    uint matAH;   // rows of A and C
    uint matAWBH; // columns of A == rows of B
    uint matBW;   // columns of B and C
} param;
layout(set = 0, binding = 0) readonly buffer InputA {
    float elem[];
} matrixInA;
layout(set = 0, binding = 1) readonly buffer InputB {
    float elem[];
} matrixInB;
layout(set = 0, binding = 2) buffer Output {
    float elem[];
} matrixOut;
layout (local_size_x = 8, local_size_y = 1, local_size_z = 1) in;

// One partial sum per invocation. The tree reduction below requires
// gl_WorkGroupSize.x to be a power of two.
shared float partialSums[gl_WorkGroupSize.x];

void main() {
    uint col = gl_WorkGroupID.x;
    uint row = gl_WorkGroupID.y;
    uint lane = gl_LocalInvocationID.x;
    // Identical for every invocation in the workgroup, so the barriers stay in uniform control flow.
    bool inRange = row < param.matAH && col < param.matBW;

    float sum = 0.0;
    if (inRange) {
        for (uint k = lane; k < param.matAWBH; k += gl_WorkGroupSize.x) {
            sum += matrixInA.elem[row * param.matAWBH + k] * matrixInB.elem[k * param.matBW + col];
        }
    }
    partialSums[lane] = sum;
    memoryBarrierShared();
    barrier();

    for (uint stride = gl_WorkGroupSize.x / 2u; stride > 0u; stride /= 2u) {
        if (lane < stride) {
            partialSums[lane] += partialSums[lane + stride];
        }
        memoryBarrierShared();
        barrier();
    }

    if (inRange && lane == 0u) {
        matrixOut.elem[row * param.matBW + col] = partialSums[0];
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment