Created
December 14, 2023 14:55
-
-
Save Kiterai/a0c440aed2f58397f002d11afe1bdaf5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cstdint>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <iterator>
#include <random>
#include <system_error>
#include <vector>

#include <vulkan/vulkan.hpp>
/// Matrix dimensions handed to the compute shader as push constants.
///
/// The member order and types must match the `Param` push-constant block
/// declared in the shader (three consecutive 32-bit unsigned integers).
struct Param {
    uint32_t matAH;   ///< Row count of matrix A (and of the result).
    uint32_t matAWBH; ///< Column count of A, which equals the row count of B.
    uint32_t matBW;   ///< Column count of matrix B (and of the result).
};

// The push-constant range is sized with sizeof(Param); any padding would
// silently break the layout shared with the shader.
static_assert(sizeof(Param) == 3 * sizeof(uint32_t), "Param must be three tightly packed 32-bit values");
struct Buffer { | |
vk::UniqueBuffer buffer; | |
vk::UniqueDeviceMemory memory; | |
}; | |
/// Creates an exclusive-sharing buffer, allocates memory for it and binds the two.
///
/// @param device      Logical device that creates the buffer and the allocation.
/// @param memProps    Memory properties of the physical device behind `device`.
/// @param bufSize     Size of the buffer in bytes.
/// @param usage       Usage flags the buffer is created with.
/// @param reqMemProp  Property flags that the chosen memory type must all have.
/// @return A Buffer owning the new buffer and its memory (bound at offset 0).
///
/// Prints an error and terminates the process with exit code -1 when no memory
/// type is both allowed for the buffer and has every requested property.
auto allocBuffer(vk::Device device, const vk::PhysicalDeviceMemoryProperties &memProps, vk::DeviceSize bufSize,
                 vk::BufferUsageFlags usage, vk::MemoryPropertyFlags reqMemProp) {
    vk::BufferCreateInfo bufferCreateInfo;
    bufferCreateInfo.size = bufSize;
    bufferCreateInfo.usage = usage;
    bufferCreateInfo.sharingMode = vk::SharingMode::eExclusive;

    Buffer buf;
    buf.buffer = device.createBufferUnique(bufferCreateInfo);

    const vk::MemoryRequirements memReq = device.getBufferMemoryRequirements(buf.buffer.get());

    vk::MemoryAllocateInfo memAllocInfo;
    memAllocInfo.allocationSize = memReq.size;

    // Pick the first memory type that the buffer accepts and that carries
    // every requested property flag.
    bool suitableMemoryTypeFound = false;
    for (uint32_t i = 0; i < memProps.memoryTypeCount; i++) {
        // Unsigned literal: shifting a signed 1 into bit 31 is undefined behaviour.
        const bool allowedForBuffer = (memReq.memoryTypeBits & (1u << i)) != 0;
        const bool hasRequiredProps = (memProps.memoryTypes[i].propertyFlags & reqMemProp) == reqMemProp;
        if (allowedForBuffer && hasRequiredProps) {
            memAllocInfo.memoryTypeIndex = i;
            suitableMemoryTypeFound = true;
            break;
        }
    }
    if (!suitableMemoryTypeFound) {
        std::cerr << "適切なメモリタイプが存在しません。" << std::endl;
        std::exit(-1);
    }

    buf.memory = device.allocateMemoryUnique(memAllocInfo);
    device.bindBufferMemory(buf.buffer.get(), buf.memory.get(), 0);
    return buf;
}
int main() { | |
vk::InstanceCreateInfo createInfo; | |
vk::UniqueInstance instance = vk::createInstanceUnique(createInfo); | |
std::vector<vk::PhysicalDevice> physicalDevices = instance->enumeratePhysicalDevices(); | |
vk::PhysicalDevice physicalDevice; | |
bool existsSuitablePhysicalDevice = false; | |
uint32_t computeQueueFamilyIndex; | |
for (size_t i = 0; i < physicalDevices.size(); i++) { | |
std::vector<vk::QueueFamilyProperties> queueProps = physicalDevices[i].getQueueFamilyProperties(); | |
bool existsComputeQueue = false; | |
for (size_t j = 0; j < queueProps.size(); j++) { | |
if (queueProps[j].queueFlags & vk::QueueFlagBits::eCompute) { | |
existsComputeQueue = true; | |
computeQueueFamilyIndex = j; | |
break; | |
} | |
} | |
if (existsComputeQueue) { | |
physicalDevice = physicalDevices[i]; | |
existsSuitablePhysicalDevice = true; | |
break; | |
} | |
} | |
if (!existsSuitablePhysicalDevice) { | |
std::cerr << "使用可能な物理デバイスがありません。" << std::endl; | |
return -1; | |
} | |
vk::DeviceCreateInfo devCreateInfo; | |
vk::DeviceQueueCreateInfo queueCreateInfo[1]; | |
queueCreateInfo[0].queueFamilyIndex = computeQueueFamilyIndex; | |
queueCreateInfo[0].queueCount = 1; | |
float queuePriorities[1] = {1.0}; | |
queueCreateInfo[0].pQueuePriorities = queuePriorities; | |
devCreateInfo.pQueueCreateInfos = queueCreateInfo; | |
devCreateInfo.queueCreateInfoCount = 1; | |
vk::UniqueDevice device = physicalDevice.createDeviceUnique(devCreateInfo); | |
vk::Queue computeQueue = device->getQueue(computeQueueFamilyIndex, 0); | |
vk::CommandPoolCreateInfo cmdPoolCreateInfo; | |
cmdPoolCreateInfo.queueFamilyIndex = computeQueueFamilyIndex; | |
vk::UniqueCommandPool cmdPool = device->createCommandPoolUnique(cmdPoolCreateInfo); | |
vk::CommandBufferAllocateInfo cmdBufAllocInfo; | |
cmdBufAllocInfo.commandPool = cmdPool.get(); | |
cmdBufAllocInfo.commandBufferCount = 1; | |
cmdBufAllocInfo.level = vk::CommandBufferLevel::ePrimary; | |
std::vector<vk::UniqueCommandBuffer> cmdBufs = | |
device->allocateCommandBuffersUnique(cmdBufAllocInfo); | |
vk::PhysicalDeviceMemoryProperties memProps = physicalDevice.getMemoryProperties(); | |
uint32_t matrixAHeight = 8; | |
uint32_t matrixAWidth = 8; | |
uint32_t matrixBHeight = 8; | |
uint32_t matrixBWidth = 8; | |
if(matrixAWidth != matrixBHeight) { | |
std::cerr << "不正な行列積" << std::endl; | |
return -1; | |
} | |
Param param; | |
param.matAH = matrixAHeight; | |
param.matAWBH = matrixAWidth; | |
param.matBW = matrixBWidth; | |
std::vector<float> matrixA(matrixAHeight * matrixAWidth); | |
std::vector<float> matrixB(matrixBHeight * matrixBWidth); | |
std::vector<float> matrixC(matrixAHeight * matrixBWidth); | |
auto matrixABuf = allocBuffer(device.get(), memProps, sizeof(float) * matrixAHeight * matrixAWidth, vk::BufferUsageFlagBits::eStorageBuffer, vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible); | |
auto matrixBBuf = allocBuffer(device.get(), memProps, sizeof(float) * matrixBHeight * matrixBWidth, vk::BufferUsageFlagBits::eStorageBuffer, vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible); | |
auto matrixCBuf = allocBuffer(device.get(), memProps, sizeof(float) * matrixAHeight * matrixBWidth, vk::BufferUsageFlagBits::eStorageBuffer, vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible); | |
std::random_device seed; | |
std::mt19937 engine(seed()); | |
std::uniform_real_distribution<float> rand_dist(-1.0f, 1.0f); | |
std::cout << "matrixA:" << std::endl; | |
{ | |
float *matrixAData = static_cast<float*>(device->mapMemory(matrixABuf.memory.get(), 0, sizeof(float) * matrixAHeight * matrixAWidth)); | |
for(int i = 0; i < matrixAHeight; i++) { | |
for(int j = 0; j < matrixAWidth; j++) { | |
matrixAData[i * matrixAWidth + j] = rand_dist(engine); | |
std::cout << matrixAData[i * matrixAWidth + j] << ' '; | |
} | |
std::cout << std::endl; | |
} | |
device->unmapMemory(matrixABuf.memory.get()); | |
} | |
std::cout << "matrixB:" << std::endl; | |
{ | |
float *matrixBData = static_cast<float*>(device->mapMemory(matrixBBuf.memory.get(), 0, sizeof(float) * matrixAHeight * matrixAWidth)); | |
for(int i = 0; i < matrixBHeight; i++) { | |
for(int j = 0; j < matrixBWidth; j++) { | |
matrixBData[i * matrixBWidth + j] = rand_dist(engine); | |
std::cout << matrixBData[i * matrixAWidth + j] << ' '; | |
} | |
std::cout << std::endl; | |
} | |
device->unmapMemory(matrixBBuf.memory.get()); | |
} | |
vk::DescriptorSetLayoutBinding descSetLayoutBinding[3]; | |
descSetLayoutBinding[0].binding = 0; | |
descSetLayoutBinding[0].descriptorType = vk::DescriptorType::eStorageBuffer; | |
descSetLayoutBinding[0].descriptorCount = 1; | |
descSetLayoutBinding[0].stageFlags = vk::ShaderStageFlagBits::eCompute; | |
descSetLayoutBinding[1].binding = 1; | |
descSetLayoutBinding[1].descriptorType = vk::DescriptorType::eStorageBuffer; | |
descSetLayoutBinding[1].descriptorCount = 1; | |
descSetLayoutBinding[1].stageFlags = vk::ShaderStageFlagBits::eCompute; | |
descSetLayoutBinding[2].binding = 2; | |
descSetLayoutBinding[2].descriptorType = vk::DescriptorType::eStorageBuffer; | |
descSetLayoutBinding[2].descriptorCount = 1; | |
descSetLayoutBinding[2].stageFlags = vk::ShaderStageFlagBits::eCompute; | |
vk::DescriptorSetLayoutCreateInfo descSetLayoutCreateInfo{}; | |
descSetLayoutCreateInfo.bindingCount = std::size(descSetLayoutBinding); | |
descSetLayoutCreateInfo.pBindings = descSetLayoutBinding; | |
vk::UniqueDescriptorSetLayout descSetLayout = device->createDescriptorSetLayoutUnique(descSetLayoutCreateInfo); | |
vk::DescriptorPoolSize descPoolSize[1]; | |
descPoolSize[0].type = vk::DescriptorType::eStorageBuffer; | |
descPoolSize[0].descriptorCount = 3; | |
vk::DescriptorPoolCreateInfo descPoolCreateInfo; | |
descPoolCreateInfo.poolSizeCount = std::size(descPoolSize); | |
descPoolCreateInfo.pPoolSizes = descPoolSize; | |
descPoolCreateInfo.maxSets = 1; | |
vk::UniqueDescriptorPool descPool = device->createDescriptorPoolUnique(descPoolCreateInfo); | |
vk::DescriptorSetAllocateInfo descSetAllocInfo; | |
auto descSetLayouts = {descSetLayout.get()}; | |
descSetAllocInfo.descriptorPool = descPool.get(); | |
descSetAllocInfo.descriptorSetCount = descSetLayouts.size(); | |
descSetAllocInfo.pSetLayouts = descSetLayouts.begin(); | |
std::vector<vk::UniqueDescriptorSet> descSets = device->allocateDescriptorSetsUnique(descSetAllocInfo); | |
vk::DescriptorBufferInfo descBufInfoA[1]; | |
descBufInfoA[0].buffer = matrixABuf.buffer.get(); | |
descBufInfoA[0].offset = 0; | |
descBufInfoA[0].range = sizeof(float) * matrixAHeight * matrixAWidth; | |
vk::DescriptorBufferInfo descBufInfoB[1]; | |
descBufInfoB[0].buffer = matrixBBuf.buffer.get(); | |
descBufInfoB[0].offset = 0; | |
descBufInfoB[0].range = sizeof(float) * matrixBHeight * matrixBWidth; | |
vk::DescriptorBufferInfo descBufInfoC[1]; | |
descBufInfoC[0].buffer = matrixCBuf.buffer.get(); | |
descBufInfoC[0].offset = 0; | |
descBufInfoC[0].range = sizeof(float) * matrixAHeight * matrixBWidth; | |
vk::WriteDescriptorSet writeDescSetA; | |
writeDescSetA.dstSet = descSets[0].get(); | |
writeDescSetA.dstBinding = 0; | |
writeDescSetA.dstArrayElement = 0; | |
writeDescSetA.descriptorType = vk::DescriptorType::eStorageBuffer; | |
writeDescSetA.descriptorCount = 1; | |
writeDescSetA.pBufferInfo = descBufInfoA; | |
vk::WriteDescriptorSet writeDescSetB; | |
writeDescSetB.dstSet = descSets[0].get(); | |
writeDescSetB.dstBinding = 1; | |
writeDescSetB.dstArrayElement = 0; | |
writeDescSetB.descriptorType = vk::DescriptorType::eStorageBuffer; | |
writeDescSetB.descriptorCount = 1; | |
writeDescSetB.pBufferInfo = descBufInfoB; | |
vk::WriteDescriptorSet writeDescSetC; | |
writeDescSetC.dstSet = descSets[0].get(); | |
writeDescSetC.dstBinding = 2; | |
writeDescSetC.dstArrayElement = 0; | |
writeDescSetC.descriptorType = vk::DescriptorType::eStorageBuffer; | |
writeDescSetC.descriptorCount = 1; | |
writeDescSetC.pBufferInfo = descBufInfoC; | |
device->updateDescriptorSets({ writeDescSetA, writeDescSetB, writeDescSetC }, {}); | |
auto pipelineDescSetLayouts = {descSetLayout.get()}; | |
vk::PushConstantRange pushConstantRanges[1]; | |
pushConstantRanges[0].offset = 0; | |
pushConstantRanges[0].size = sizeof(Param); | |
pushConstantRanges[0].stageFlags = vk::ShaderStageFlagBits::eCompute; | |
vk::PipelineLayoutCreateInfo layoutCreateInfo; | |
layoutCreateInfo.setLayoutCount = pipelineDescSetLayouts.size(); | |
layoutCreateInfo.pSetLayouts = pipelineDescSetLayouts.begin(); | |
layoutCreateInfo.pushConstantRangeCount = std::size(pushConstantRanges); | |
layoutCreateInfo.pPushConstantRanges = pushConstantRanges; | |
vk::UniquePipelineLayout pipelineLayout = device->createPipelineLayoutUnique(layoutCreateInfo); | |
size_t spvFileSz = std::filesystem::file_size("shader.comp.spv"); | |
std::ifstream spvFile("shader.comp.spv", std::ios_base::binary); | |
std::vector<char> spvFileData(spvFileSz); | |
spvFile.read(spvFileData.data(), spvFileSz); | |
vk::ShaderModuleCreateInfo computeShaderCreateInfo; | |
computeShaderCreateInfo.codeSize = spvFileSz; | |
computeShaderCreateInfo.pCode = reinterpret_cast<const uint32_t *>(spvFileData.data()); | |
vk::UniqueShaderModule computeShader = device->createShaderModuleUnique(computeShaderCreateInfo); | |
vk::ComputePipelineCreateInfo pipelineCreateInfo; | |
pipelineCreateInfo.layout = pipelineLayout.get(); | |
pipelineCreateInfo.stage.stage = vk::ShaderStageFlagBits::eCompute; | |
pipelineCreateInfo.stage.module = computeShader.get(); | |
pipelineCreateInfo.stage.pName = "main"; | |
vk::UniquePipeline pipeline = device->createComputePipelineUnique(nullptr, pipelineCreateInfo).value; | |
vk::CommandBufferBeginInfo cmdBeginInfo; | |
cmdBufs[0]->begin(cmdBeginInfo); | |
cmdBufs[0]->bindPipeline(vk::PipelineBindPoint::eCompute, pipeline.get()); | |
cmdBufs[0]->bindDescriptorSets(vk::PipelineBindPoint::eCompute, pipelineLayout.get(), 0, {descSets[0].get()}, {}); | |
cmdBufs[0]->pushConstants(pipelineLayout.get(), vk::ShaderStageFlagBits::eCompute, 0, sizeof(Param), ¶m); | |
cmdBufs[0]->dispatch(matrixAHeight, matrixBWidth, 1); | |
cmdBufs[0]->end(); | |
vk::CommandBuffer submitCmdBuf[1] = {cmdBufs[0].get()}; | |
vk::SubmitInfo submitInfo; | |
submitInfo.commandBufferCount = 1; | |
submitInfo.pCommandBuffers = submitCmdBuf; | |
computeQueue.submit({submitInfo}, nullptr); | |
computeQueue.waitIdle(); | |
std::cout << "matrixA x matrixB = " << std::endl; | |
float *calcData = static_cast<float*>(device->mapMemory(matrixCBuf.memory.get(), 0, sizeof(float) * matrixAHeight * matrixBWidth)); | |
for(int i = 0; i < matrixAHeight; i++) { | |
for(int j = 0; j < matrixBWidth; j++) { | |
std::cout << calcData[i * matrixBWidth + j] << ' '; | |
} | |
std::cout << std::endl; | |
} | |
device->unmapMemory(matrixCBuf.memory.get()); | |
computeQueue.waitIdle(); | |
return 0; | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#version 450
#extension GL_KHR_shader_subgroup_arithmetic : enable

// Matrix product C = A * B for row-major float matrices.
//
// One workgroup computes one element of C: gl_WorkGroupID.x selects the output
// column and gl_WorkGroupID.y the output row. The invocations of the workgroup
// split the dot product between them and combine their partial sums with
// subgroupAdd.
//
// NOTE(review): subgroupAdd only reduces within a subgroup, so this relies on
// all local_size_x invocations of a workgroup belonging to the same subgroup
// (subgroup size >= 8) - confirm for the target devices.

// Matrix dimensions supplied by the host as push constants.
layout(push_constant) uniform Param {
    uint matAH;    // rows of A (and of the output)
    uint matAWBH;  // columns of A == rows of B
    uint matBW;    // columns of B (and of the output)
} param;

layout(set = 0, binding = 0) readonly buffer InputA {
    float elem[];
} matrixInA;

layout(set = 0, binding = 1) readonly buffer InputB {
    float elem[];
} matrixInB;

layout(set = 0, binding = 2) buffer Output {
    float elem[];
} matrixOut;

layout (local_size_x = 8, local_size_y = 1, local_size_z = 1) in;

void main() {
    const uint row = gl_WorkGroupID.y;
    const uint col = gl_WorkGroupID.x;

    // Uniform across the workgroup, so every invocation leaves together and
    // the subgroup operation below is never reached by only part of the group.
    if (row >= param.matAH || col >= param.matBW) {
        return;
    }

    // Each invocation accumulates every gl_WorkGroupSize.x-th term. This keeps
    // reads in bounds when the inner dimension is smaller than the workgroup
    // and covers all terms when it is larger.
    float partial = 0.0;
    for (uint k = gl_LocalInvocationID.x; k < param.matAWBH; k += gl_WorkGroupSize.x) {
        partial += matrixInA.elem[row * param.matAWBH + k] * matrixInB.elem[k * param.matBW + col];
    }

    const float dotProduct = subgroupAdd(partial);

    // Every invocation holds the same sum; a single writer is enough.
    if (gl_LocalInvocationID.x == 0u) {
        matrixOut.elem[row * param.matBW + col] = dotProduct;
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment