Skip to content

Instantly share code, notes, and snippets.

@csking101
Created July 19, 2023 06:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save csking101/38dff660c27f832857387fedbc76c1e7 to your computer and use it in GitHub Desktop.
Save csking101/38dff660c27f832857387fedbc76c1e7 to your computer and use it in GitHub Desktop.
3D Convolution
#include <array>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>
using namespace std;
/**
 * 3-D convolution of a single-channel volume with a single kernel.
 *
 * Strictly speaking this is a cross-correlation: the kernel is applied
 * as-is, without flipping. Tensors are nested vectors indexed as
 * [depth][height][width] and must be rectangular, i.e. every row at a given
 * nesting level has the same length.
 *
 * @param image   Input volume of shape (D, H, W). May be empty.
 * @param kernel  Filter of shape (KD, KH, KW); every dimension must be >= 1.
 * @param stride  Step between successive kernel positions, shared by all
 *                three axes; must be >= 1.
 * @param padding Number of zero voxels added on both sides of every axis;
 *                must be >= 0. For an odd kernel size k and stride 1,
 *                padding = (k - 1) / 2 keeps that axis the same length as
 *                the input ("same" convolution).
 * @return Output volume of shape (OD, OH, OW). Per axis,
 *         O = (N + 2*padding - K) / stride + 1, or 0 when the padded input is
 *         shorter than the kernel along that axis.
 * @throws std::invalid_argument if stride < 1, padding < 0, the kernel has
 *         an empty dimension, or either tensor is ragged.
 */
std::vector<std::vector<std::vector<float>>> Conv3D(
    const std::vector<std::vector<std::vector<float>>>& image,
    const std::vector<std::vector<std::vector<float>>>& kernel,
    int stride,
    int padding) {
    using Tensor3 = std::vector<std::vector<std::vector<float>>>;

    if (stride < 1) {
        throw std::invalid_argument("Conv3D: stride must be >= 1");
    }
    if (padding < 0) {
        throw std::invalid_argument("Conv3D: padding must be >= 0");
    }

    // Returns {depth, height, width} without indexing into empty levels.
    // Ragged tensors are rejected, because they would otherwise be read out
    // of bounds in the loops below.
    auto shape_of = [](const Tensor3& t, const char* what) -> std::array<int, 3> {
        std::array<int, 3> dims = {{static_cast<int>(t.size()), 0, 0}};
        if (!t.empty()) {
            dims[1] = static_cast<int>(t[0].size());
            if (!t[0].empty()) {
                dims[2] = static_cast<int>(t[0][0].size());
            }
        }
        for (const auto& plane : t) {
            if (static_cast<int>(plane.size()) != dims[1]) {
                throw std::invalid_argument(std::string("Conv3D: ") + what + " is not rectangular");
            }
            for (const auto& row : plane) {
                if (static_cast<int>(row.size()) != dims[2]) {
                    throw std::invalid_argument(std::string("Conv3D: ") + what + " is not rectangular");
                }
            }
        }
        return dims;
    };

    const std::array<int, 3> image_shape = shape_of(image, "image");
    const std::array<int, 3> kernel_shape = shape_of(kernel, "kernel");
    const int image_depth = image_shape[0];
    const int image_height = image_shape[1];
    const int image_width = image_shape[2];
    const int kernel_depth = kernel_shape[0];
    const int kernel_height = kernel_shape[1];
    const int kernel_width = kernel_shape[2];

    if (kernel_depth == 0 || kernel_height == 0 || kernel_width == 0) {
        throw std::invalid_argument("Conv3D: kernel must be non-empty in every dimension");
    }

    const int padded_depth = image_depth + 2 * padding;
    const int padded_height = image_height + 2 * padding;
    const int padded_width = image_width + 2 * padding;

    // Number of kernel positions that fit along one axis. The explicit
    // guard matters because integer division truncates toward zero:
    // (2 - 3) / 2 + 1 == 1 would otherwise request a window that runs past
    // the end of the padded input.
    auto positions = [stride](int padded_len, int kernel_len) -> int {
        return padded_len < kernel_len ? 0 : (padded_len - kernel_len) / stride + 1;
    };
    const int output_depth = positions(padded_depth, kernel_depth);
    const int output_height = positions(padded_height, kernel_height);
    const int output_width = positions(padded_width, kernel_width);

    Tensor3 output(output_depth,
                   std::vector<std::vector<float>>(output_height,
                                                   std::vector<float>(output_width, 0.0f)));

    // Build the zero-padded input only when padding is requested. With
    // padding == 0 the caller's tensor is read directly, avoiding a copy.
    Tensor3 padded_storage;
    const Tensor3* source = &image;
    if (padding > 0) {
        padded_storage.assign(padded_depth,
                              std::vector<std::vector<float>>(padded_height,
                                                              std::vector<float>(padded_width, 0.0f)));
        for (int d = 0; d < image_depth; ++d) {
            for (int i = 0; i < image_height; ++i) {
                for (int j = 0; j < image_width; ++j) {
                    padded_storage[d + padding][i + padding][j + padding] = image[d][i][j];
                }
            }
        }
        source = &padded_storage;
    }
    const Tensor3& src = *source;

    // Slide the kernel over the (padded) input. Terms are summed in the
    // same kd -> kh -> kw order as before, so float results are unchanged.
    for (int z = 0; z < output_depth; ++z) {
        for (int y = 0; y < output_height; ++y) {
            for (int x = 0; x < output_width; ++x) {
                float acc = 0.0f;
                for (int kd = 0; kd < kernel_depth; ++kd) {
                    for (int kh = 0; kh < kernel_height; ++kh) {
                        for (int kw = 0; kw < kernel_width; ++kw) {
                            acc += src[z * stride + kd][y * stride + kh][x * stride + kw] * kernel[kd][kh][kw];
                        }
                    }
                }
                output[z][y][x] = acc;
            }
        }
    }
    return output;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment