Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@goldsborough
Last active November 27, 2023 05:59
Show Gist options
  • Star 37 You must be signed in to star a gist
  • Fork 8 You must be signed in to fork a gist
  • Save goldsborough/865e6717e64fbae75cdaf6c9914a130d to your computer and use it in GitHub Desktop.
Save goldsborough/865e6717e64fbae75cdaf6c9914a130d to your computer and use it in GitHub Desktop.
Convolution with cuDNN
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <vector>

#include <cudnn.h>
#include <opencv2/opencv.hpp>
// Evaluates a cuDNN call exactly once; if it does not return
// CUDNN_STATUS_SUCCESS, prints the line number and cuDNN's error string to
// stderr and terminates the process with EXIT_FAILURE.
//
// The body is wrapped in do { ... } while (0) so that `checkCUDNN(x);` is a
// single statement and stays correct inside an unbraced if/else. The local is
// given a trailing underscore so it is unlikely to shadow a caller variable
// that appears inside `expression`.
#define checkCUDNN(expression)                                        \
  do {                                                                \
    const cudnnStatus_t cudnn_status_ = (expression);                 \
    if (cudnn_status_ != CUDNN_STATUS_SUCCESS) {                      \
      std::cerr << "Error on line " << __LINE__ << ": "               \
                << cudnnGetErrorString(cudnn_status_) << std::endl;   \
      std::exit(EXIT_FAILURE);                                        \
    }                                                                 \
  } while (0)
// Reads the image at `image_path` as a 3-channel colour image, converts it to
// 32-bit float and min-max normalizes the values into [0, 1].
//
// Prints the image dimensions to stderr. If the file cannot be read, prints an
// error and terminates the process with EXIT_FAILURE instead of handing an
// empty matrix to the caller.
cv::Mat load_image(const char* image_path) {
  // cv::IMREAD_COLOR is the C++ API spelling of the legacy
  // CV_LOAD_IMAGE_COLOR flag.
  cv::Mat image = cv::imread(image_path, cv::IMREAD_COLOR);
  if (image.empty()) {
    std::cerr << "Could not read image: " << image_path << std::endl;
    std::exit(EXIT_FAILURE);
  }

  image.convertTo(image, CV_32FC3);
  cv::normalize(image, image, 0, 1, cv::NORM_MINMAX);

  std::cerr << "Input Image: " << image.rows << " x " << image.cols << " x "
            << image.channels() << std::endl;
  return image;
}
// Writes a float image to `output_filename`.
//
// Parameters:
//   output_filename  destination path passed to cv::imwrite.
//   buffer           pixel data interpreted as `height` x `width` x 3 floats
//                    (CV_32FC3). The matrix wraps this pointer, and the
//                    clamp/normalize steps below operate in place on it, so
//                    the caller's data is modified.
//   height, width    image dimensions in pixels.
//
// Negative values are clamped to zero, the result is min-max normalized to
// [0, 255] and converted to 8-bit before writing. Success or failure of the
// write is reported on stderr.
void save_image(const char* output_filename,
                float* buffer,
                int height,
                int width) {
  cv::Mat output_image(height, width, CV_32FC3, buffer);
  // Make negative values zero.
  cv::threshold(output_image,
                output_image,
                /*threshold=*/0,
                /*maxval=*/0,
                cv::THRESH_TOZERO);
  cv::normalize(output_image, output_image, 0.0, 255.0, cv::NORM_MINMAX);
  output_image.convertTo(output_image, CV_8UC3);

  // imwrite reports failure through its return value; do not claim success
  // when nothing was written.
  if (!cv::imwrite(output_filename, output_image)) {
    std::cerr << "Failed to write output to " << output_filename << std::endl;
    return;
  }
  std::cerr << "Wrote output to " << output_filename << std::endl;
}
int main(int argc, const char* argv[]) {
if (argc < 2) {
std::cerr << "usage: conv <image> [gpu=0] [sigmoid=0]" << std::endl;
std::exit(EXIT_FAILURE);
}
int gpu_id = (argc > 2) ? std::atoi(argv[2]) : 0;
std::cerr << "GPU: " << gpu_id << std::endl;
bool with_sigmoid = (argc > 3) ? std::atoi(argv[3]) : 0;
std::cerr << "With sigmoid: " << std::boolalpha << with_sigmoid << std::endl;
cv::Mat image = load_image(argv[1]);
cudaSetDevice(gpu_id);
cudnnHandle_t cudnn;
cudnnCreate(&cudnn);
cudnnTensorDescriptor_t input_descriptor;
checkCUDNN(cudnnCreateTensorDescriptor(&input_descriptor));
checkCUDNN(cudnnSetTensor4dDescriptor(input_descriptor,
/*format=*/CUDNN_TENSOR_NHWC,
/*dataType=*/CUDNN_DATA_FLOAT,
/*batch_size=*/1,
/*channels=*/3,
/*image_height=*/image.rows,
/*image_width=*/image.cols));
cudnnFilterDescriptor_t kernel_descriptor;
checkCUDNN(cudnnCreateFilterDescriptor(&kernel_descriptor));
checkCUDNN(cudnnSetFilter4dDescriptor(kernel_descriptor,
/*dataType=*/CUDNN_DATA_FLOAT,
/*format=*/CUDNN_TENSOR_NCHW,
/*out_channels=*/3,
/*in_channels=*/3,
/*kernel_height=*/3,
/*kernel_width=*/3));
cudnnConvolutionDescriptor_t convolution_descriptor;
checkCUDNN(cudnnCreateConvolutionDescriptor(&convolution_descriptor));
checkCUDNN(cudnnSetConvolution2dDescriptor(convolution_descriptor,
/*pad_height=*/1,
/*pad_width=*/1,
/*vertical_stride=*/1,
/*horizontal_stride=*/1,
/*dilation_height=*/1,
/*dilation_width=*/1,
/*mode=*/CUDNN_CROSS_CORRELATION,
/*computeType=*/CUDNN_DATA_FLOAT));
int batch_size{0}, channels{0}, height{0}, width{0};
checkCUDNN(cudnnGetConvolution2dForwardOutputDim(convolution_descriptor,
input_descriptor,
kernel_descriptor,
&batch_size,
&channels,
&height,
&width));
std::cerr << "Output Image: " << height << " x " << width << " x " << channels
<< std::endl;
cudnnTensorDescriptor_t output_descriptor;
checkCUDNN(cudnnCreateTensorDescriptor(&output_descriptor));
checkCUDNN(cudnnSetTensor4dDescriptor(output_descriptor,
/*format=*/CUDNN_TENSOR_NHWC,
/*dataType=*/CUDNN_DATA_FLOAT,
/*batch_size=*/1,
/*channels=*/3,
/*image_height=*/image.rows,
/*image_width=*/image.cols));
cudnnConvolutionFwdAlgo_t convolution_algorithm;
checkCUDNN(
cudnnGetConvolutionForwardAlgorithm(cudnn,
input_descriptor,
kernel_descriptor,
convolution_descriptor,
output_descriptor,
CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
/*memoryLimitInBytes=*/0,
&convolution_algorithm));
size_t workspace_bytes{0};
checkCUDNN(cudnnGetConvolutionForwardWorkspaceSize(cudnn,
input_descriptor,
kernel_descriptor,
convolution_descriptor,
output_descriptor,
convolution_algorithm,
&workspace_bytes));
std::cerr << "Workspace size: " << (workspace_bytes / 1048576.0) << "MB"
<< std::endl;
assert(workspace_bytes > 0);
void* d_workspace{nullptr};
cudaMalloc(&d_workspace, workspace_bytes);
int image_bytes = batch_size * channels * height * width * sizeof(float);
float* d_input{nullptr};
cudaMalloc(&d_input, image_bytes);
cudaMemcpy(d_input, image.ptr<float>(0), image_bytes, cudaMemcpyHostToDevice);
float* d_output{nullptr};
cudaMalloc(&d_output, image_bytes);
cudaMemset(d_output, 0, image_bytes);
// clang-format off
const float kernel_template[3][3] = {
{1, 1, 1},
{1, -8, 1},
{1, 1, 1}
};
// clang-format on
float h_kernel[3][3][3][3];
for (int kernel = 0; kernel < 3; ++kernel) {
for (int channel = 0; channel < 3; ++channel) {
for (int row = 0; row < 3; ++row) {
for (int column = 0; column < 3; ++column) {
h_kernel[kernel][channel][row][column] = kernel_template[row][column];
}
}
}
}
float* d_kernel{nullptr};
cudaMalloc(&d_kernel, sizeof(h_kernel));
cudaMemcpy(d_kernel, h_kernel, sizeof(h_kernel), cudaMemcpyHostToDevice);
const float alpha = 1.0f, beta = 0.0f;
checkCUDNN(cudnnConvolutionForward(cudnn,
&alpha,
input_descriptor,
d_input,
kernel_descriptor,
d_kernel,
convolution_descriptor,
convolution_algorithm,
d_workspace,
workspace_bytes,
&beta,
output_descriptor,
d_output));
if (with_sigmoid) {
cudnnActivationDescriptor_t activation_descriptor;
checkCUDNN(cudnnCreateActivationDescriptor(&activation_descriptor));
checkCUDNN(cudnnSetActivationDescriptor(activation_descriptor,
CUDNN_ACTIVATION_SIGMOID,
CUDNN_PROPAGATE_NAN,
/*relu_coef=*/0));
checkCUDNN(cudnnActivationForward(cudnn,
activation_descriptor,
&alpha,
output_descriptor,
d_output,
&beta,
output_descriptor,
d_output));
cudnnDestroyActivationDescriptor(activation_descriptor);
}
float* h_output = new float[image_bytes];
cudaMemcpy(h_output, d_output, image_bytes, cudaMemcpyDeviceToHost);
save_image("cudnn-out.png", h_output, height, width);
delete[] h_output;
cudaFree(d_kernel);
cudaFree(d_input);
cudaFree(d_output);
cudaFree(d_workspace);
cudnnDestroyTensorDescriptor(input_descriptor);
cudnnDestroyTensorDescriptor(output_descriptor);
cudnnDestroyFilterDescriptor(kernel_descriptor);
cudnnDestroyConvolutionDescriptor(convolution_descriptor);
cudnnDestroy(cudnn);
}
@aerodame
Copy link

It would be great if this example came with a full list of prerequisites for the CUDA toolkit and cuDNN, as well as a Makefile that parallels the examples shipped with cuDNN. It is also missing instructions for installing OpenCV 2, which is required by the included header file.

@verneyleaf
Copy link

I installed cuDNN 6.5 and CUDA 6.5 on a Jetson TK1. When I run your code, I get an error:
cudnnSetFilter4dDescriptor: too many arguments
Thanks for your answer

@adam-dziedzic
Copy link

@aerodame

The Makefile can be found here: Peter's blog post

For your convenience:

# Builds the cuDNN convolution example with nvcc.
# CUDNN_PATH must point at a directory containing cuDNN's include/ and lib64/.
CXX := nvcc
TARGET := conv
CUDNN_PATH := cudnn
HEADERS := -I $(CUDNN_PATH)/include
LIBS := -L $(CUDNN_PATH)/lib64 -L/usr/local/lib
CXXFLAGS := -arch=sm_35 -std=c++11 -O2

# Special targets are case-sensitive: it must be .PHONY, and both non-file
# targets are listed so a stray file named "all" or "clean" cannot mask them.
.PHONY: all clean

all: $(TARGET)

$(TARGET): $(TARGET).cu
	$(CXX) $(CXXFLAGS) $(HEADERS) $(LIBS) $(TARGET).cu -o $(TARGET) \
	-lcudnn -lopencv_imgcodecs -lopencv_imgproc -lopencv_core

# -f makes the removal succeed even when the binary does not exist.
clean:
	rm -f $(TARGET)

@AshburnLee
Copy link

I've read your blog, it helps a lot!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment