Skip to content

Instantly share code, notes, and snippets.

@Mikea15
Created December 18, 2024 09:39
Show Gist options
  • Save Mikea15/7afb75fa1bd61aba24819cce90e9d184 to your computer and use it in GitHub Desktop.
Save Mikea15/7afb75fa1bd61aba24819cce90e9d184 to your computer and use it in GitHub Desktop.
Hello Compute
// GLEW must be included before any other OpenGL header.
#include <GL/glew.h>
#include <SDL2/SDL.h>
#include <algorithm>
#include <iostream>
#include <limits>
#include <string>
#include <vector>
#ifdef _DEBUG
#pragma comment(lib, "SDL2maind")
#else
#pragma comment(lib, "SDL2main")
#endif
class ComputeShaderHandler {
private:
GLuint computeProgram;
GLuint inputBuffer1, inputBuffer2, outputBuffer;
// Shader source (you'll replace this with your actual compute shader code)
const char* computeShaderSource = R"(
#version 430
layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
// Input buffers
layout(std430, binding = 0) buffer InputBuffer1 {
float data1[];
};
layout(std430, binding = 1) buffer InputBuffer2 {
float data2[];
};
// Output buffer
layout(std430, binding = 2) buffer OutputBuffer {
float processedData[];
};
uniform int dataSize;
void main() {
uint index = gl_GlobalInvocationID.x;
if (index < dataSize) {
// Example processing: simple element-wise addition
processedData[index] = data1[index] + data2[index];
}
}
)";
public:
ComputeShaderHandler() {
// Initialize GLEW and create compute shader
SDL_Window* window;
SDL_GLContext glContext;
// Initialize SDL
if (SDL_Init(SDL_INIT_VIDEO) < 0) {
std::cerr << "SDL could not initialize! SDL_Error: " << SDL_GetError() << std::endl;
return;
}
// Set OpenGL attributes before creating window
SDL_GL_SetAttribute(SDL_GL_CONTEXT_MAJOR_VERSION, 4);
SDL_GL_SetAttribute(SDL_GL_CONTEXT_MINOR_VERSION, 3); // For compute shaders
SDL_GL_SetAttribute(SDL_GL_CONTEXT_PROFILE_MASK, SDL_GL_CONTEXT_PROFILE_CORE);
// Create window
window = SDL_CreateWindow("Compute Shader Test",
SDL_WINDOWPOS_UNDEFINED,
SDL_WINDOWPOS_UNDEFINED,
800, 600,
SDL_WINDOW_OPENGL | SDL_WINDOW_SHOWN
);
if (!window) {
std::cerr << "Window could not be created! SDL_Error: " << SDL_GetError() << std::endl;
SDL_Quit();
return;
}
// Create OpenGL context
glContext = SDL_GL_CreateContext(window);
if (!glContext) {
std::cerr << "OpenGL context could not be created! SDL_Error: " << SDL_GetError() << std::endl;
SDL_DestroyWindow(window);
SDL_Quit();
return;
}
// Initialize GLEW AFTER creating a valid OpenGL context
GLenum err = glewInit();
if (GLEW_OK != err) {
// Problem: glewInit failed, something is seriously wrong
std::cerr << "Error: " << glewGetErrorString(err) << std::endl;
SDL_GL_DeleteContext(glContext);
SDL_DestroyWindow(window);
SDL_Quit();
return;
}
initComputeShader();
}
void initComputeShader() {
// Compile compute shader
GLuint shader = glCreateShader(GL_COMPUTE_SHADER);
glShaderSource(shader, 1, &computeShaderSource, nullptr);
glCompileShader(shader);
// Check shader compilation
GLint success;
glGetShaderiv(shader, GL_COMPILE_STATUS, &success);
if (!success) {
GLchar infoLog[512];
glGetShaderInfoLog(shader, sizeof(infoLog), nullptr, infoLog);
std::cerr << "Compute Shader Compilation Failed:\n" << infoLog << std::endl;
return;
}
// Create shader program
computeProgram = glCreateProgram();
glAttachShader(computeProgram, shader);
glLinkProgram(computeProgram);
// Check program linking
glGetProgramiv(computeProgram, GL_LINK_STATUS, &success);
if (!success) {
GLchar infoLog[512];
glGetProgramInfoLog(computeProgram, sizeof(infoLog), nullptr, infoLog);
std::cerr << "Shader Program Linking Failed:\n" << infoLog << std::endl;
return;
}
glDeleteShader(shader);
}
void setupBuffers(const std::vector<float>& input1, const std::vector<float>& input2) {
// Create and bind input buffer 1
glGenBuffers(1, &inputBuffer1);
glBindBuffer(GL_SHADER_STORAGE_BUFFER, inputBuffer1);
glBufferData(GL_SHADER_STORAGE_BUFFER, input1.size() * sizeof(float), input1.data(), GL_DYNAMIC_DRAW);
glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, inputBuffer1);
// Create and bind input buffer 2
glGenBuffers(1, &inputBuffer2);
glBindBuffer(GL_SHADER_STORAGE_BUFFER, inputBuffer2);
glBufferData(GL_SHADER_STORAGE_BUFFER, input2.size() * sizeof(float), input2.data(), GL_DYNAMIC_DRAW);
glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 1, inputBuffer2);
// Create output buffer
glGenBuffers(1, &outputBuffer);
glBindBuffer(GL_SHADER_STORAGE_BUFFER, outputBuffer);
glBufferData(GL_SHADER_STORAGE_BUFFER, input1.size() * sizeof(float), nullptr, GL_DYNAMIC_DRAW);
glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 2, outputBuffer);
}
void dispatchCompute(size_t dataSize) {
// Use the compute shader program
glUseProgram(computeProgram);
// Set uniform for data size
GLint dataSizeLocation = glGetUniformLocation(computeProgram, "dataSize");
glUniform1i(dataSizeLocation, dataSize);
// Dispatch compute shader
glDispatchCompute((dataSize + 15) / 16, 1, 1);
// Ensure compute shader operations are finished
glMemoryBarrier(GL_SHADER_STORAGE_BARRIER_BIT);
}
std::vector<float> retrieveOutputData(size_t dataSize) {
std::vector<float> outputData(dataSize);
// Bind the output buffer and retrieve data
glBindBuffer(GL_SHADER_STORAGE_BUFFER, outputBuffer);
float* mappedData = (float*)glMapBuffer(GL_SHADER_STORAGE_BUFFER, GL_READ_ONLY);
if (mappedData) {
std::copy(mappedData, mappedData + dataSize, outputData.begin());
glUnmapBuffer(GL_SHADER_STORAGE_BUFFER);
}
else {
std::cerr << "Failed to map buffer" << std::endl;
}
return outputData;
}
// Cleanup method
~ComputeShaderHandler() {
glDeleteBuffers(1, &inputBuffer1);
glDeleteBuffers(1, &inputBuffer2);
glDeleteBuffers(1, &outputBuffer);
glDeleteProgram(computeProgram);
}
};
// Example usage
int main(int argc, char* argv[])
{
// SDL and OpenGL initialization would go here
std::vector<float> input1 = { 1.0f, 2.0f, 3.0f, 4.0f };
std::vector<float> input2 = { 5.0f, 6.0f, 7.0f, 8.0f };
ComputeShaderHandler computeHandler;
// Setup buffers with input data
computeHandler.setupBuffers(input1, input2);
// Dispatch compute shader
computeHandler.dispatchCompute(input1.size());
// Retrieve processed data
std::vector<float> outputData = computeHandler.retrieveOutputData(input1.size());
// Print results
for (float val : outputData) {
std::cout << val << " ";
}
std::cout << std::endl;
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment