Skip to content

Instantly share code, notes, and snippets.

@jankolf
Last active February 2, 2023 21:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jankolf/cd52c6d0ba9d8e427792d15dc5392c38 to your computer and use it in GitHub Desktop.
Save jankolf/cd52c6d0ba9d8e427792d15dc5392c38 to your computer and use it in GitHub Desktop.
PyTorch CUDA Extension
python setup.py install terminal output
$> python setup.py install
running install
running bdist_egg
running egg_info
writing reduce_cuda.egg-info/PKG-INFO
writing dependency_links to reduce_cuda.egg-info/dependency_links.txt
writing top-level names to reduce_cuda.egg-info/top_level.txt
/home/user/anaconda3/envs/torch11/lib/python3.8/site-packages/torch/utils/cpp_extension.py:387: UserWarning: Attempted to use ninja as the BuildExtension backend but we could not find ninja.. Falling back to using the slow distutils backend.
warnings.warn(msg.format('we could not find ninja.'))
reading manifest file 'reduce_cuda.egg-info/SOURCES.txt'
writing manifest file 'reduce_cuda.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_ext
/home/user/anaconda3/envs/torch11/lib/python3.8/site-packages/torch/utils/cpp_extension.py:788: UserWarning: The detected CUDA version (11.4) has a minor version mismatch with the version that was used to compile PyTorch (11.3). Most likely this shouldn't be a problem.
warnings.warn(CUDA_MISMATCH_WARN.format(cuda_str_version, torch.version.cuda))
building 'reduce_cuda' extension
gcc -pthread -B /home/user/anaconda3/envs/torch11/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I/home/user/anaconda3/envs/torch11/lib/python3.8/site-packages/torch/include -I/home/user/anaconda3/envs/torch11/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/home/user/anaconda3/envs/torch11/lib/python3.8/site-packages/torch/include/TH -I/home/user/anaconda3/envs/torch11/lib/python3.8/site-packages/torch/include/THC -I/usr/local/cuda/include -I/home/user/anaconda3/envs/torch11/include/python3.8 -c reduce_cuda.cpp -o build/temp.linux-x86_64-3.8/reduce_cuda.o -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE="_gcc" -DPYBIND11_STDLIB="_libstdcpp" -DPYBIND11_BUILD_ABI="_cxxabi1011" -DTORCH_EXTENSION_NAME=reduce_cuda -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++14
cc1plus: warning: command line option ‘-Wstrict-prototypes’ is valid for C/ObjC but not for C++
/usr/local/cuda/bin/nvcc -I/home/user/anaconda3/envs/torch11/lib/python3.8/site-packages/torch/include -I/home/user/anaconda3/envs/torch11/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/home/user/anaconda3/envs/torch11/lib/python3.8/site-packages/torch/include/TH -I/home/user/anaconda3/envs/torch11/lib/python3.8/site-packages/torch/include/THC -I/usr/local/cuda/include -I/home/user/anaconda3/envs/torch11/include/python3.8 -c reduce_cuda_kernel.cu -o build/temp.linux-x86_64-3.8/reduce_cuda_kernel.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options '-fPIC' -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE="_gcc" -DPYBIND11_STDLIB="_libstdcpp" -DPYBIND11_BUILD_ABI="_cxxabi1011" -DTORCH_EXTENSION_NAME=reduce_cuda -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_75,code=compute_75 -gencode=arch=compute_75,code=sm_75 -std=c++14
g++ -pthread -shared -B /home/user/anaconda3/envs/torch11/compiler_compat -L/home/user/anaconda3/envs/torch11/lib -Wl,-rpath=/home/user/anaconda3/envs/torch11/lib -Wl,--no-as-needed -Wl,--sysroot=/ build/temp.linux-x86_64-3.8/reduce_cuda.o build/temp.linux-x86_64-3.8/reduce_cuda_kernel.o -L/home/user/anaconda3/envs/torch11/lib/python3.8/site-packages/torch/lib -L/usr/local/cuda/lib64 -lc10 -ltorch -ltorch_cpu -ltorch_python -lcudart -lc10_cuda -ltorch_cuda_cu -ltorch_cuda_cpp -o build/lib.linux-x86_64-3.8/reduce_cuda.cpython-38-x86_64-linux-gnu.so
creating build/bdist.linux-x86_64/egg
copying build/lib.linux-x86_64-3.8/reduce_cuda.cpython-38-x86_64-linux-gnu.so -> build/bdist.linux-x86_64/egg
creating stub loader for reduce_cuda.cpython-38-x86_64-linux-gnu.so
byte-compiling build/bdist.linux-x86_64/egg/reduce_cuda.py to reduce_cuda.cpython-38.pyc
creating build/bdist.linux-x86_64/egg/EGG-INFO
copying reduce_cuda.egg-info/PKG-INFO -> build/bdist.linux-x86_64/egg/EGG-INFO
copying reduce_cuda.egg-info/SOURCES.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
copying reduce_cuda.egg-info/dependency_links.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
copying reduce_cuda.egg-info/top_level.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
writing build/bdist.linux-x86_64/egg/EGG-INFO/native_libs.txt
zip_safe flag not set; analyzing archive contents...
__pycache__.reduce_cuda.cpython-38: module references __file__
creating 'dist/reduce_cuda-0.0.0-py3.8-linux-x86_64.egg' and adding 'build/bdist.linux-x86_64/egg' to it
removing 'build/bdist.linux-x86_64/egg' (and everything under it)
Processing reduce_cuda-0.0.0-py3.8-linux-x86_64.egg
removing '/home/user/anaconda3/envs/torch11/lib/python3.8/site-packages/reduce_cuda-0.0.0-py3.8-linux-x86_64.egg' (and everything under it)
creating /home/user/anaconda3/envs/torch11/lib/python3.8/site-packages/reduce_cuda-0.0.0-py3.8-linux-x86_64.egg
Extracting reduce_cuda-0.0.0-py3.8-linux-x86_64.egg to /home/user/anaconda3/envs/torch11/lib/python3.8/site-packages
reduce-cuda 0.0.0 is already the active version in easy-install.pth
Installed /home/user/anaconda3/envs/torch11/lib/python3.8/site-packages/reduce_cuda-0.0.0-py3.8-linux-x86_64.egg
Processing dependencies for reduce-cuda==0.0.0
Finished processing dependencies for reduce-cuda==0.0.0
#include <torch/extension.h>
#include <vector>
// Forward declaration of the CUDA launcher; its definition is compiled
// separately by nvcc (see the 'reduce_cuda_kernel.cu' entry in setup.py).
torch::Tensor reduce_cuda(torch::Tensor matrix);
// C++ interface
// Input validation helpers.
// TORCH_CHECK is used instead of the deprecated AT_ASSERTM so that a bad
// argument is reported as an ordinary user-facing error (it still surfaces
// as a c10::Error, i.e. a RuntimeError on the Python side).
// The argument is parenthesised so expressions can be passed safely.
#define CHECK_CUDA(x) TORCH_CHECK((x).device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
// Wrapped in do/while(0) so the two checks behave as a single statement,
// e.g. when CHECK_INPUT is used as the body of an unbraced `if`.
#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)
/**
 * Entry point exposed to Python: validates the argument and forwards it to
 * the CUDA implementation.
 *
 * @param matrix Tensor that must live on a CUDA device and be contiguous.
 * @return Whatever reduce_cuda(matrix) produces.
 */
torch::Tensor reduce(torch::Tensor matrix)
{
    // Same two checks CHECK_INPUT performs, spelled out individually.
    CHECK_CUDA(matrix);
    CHECK_CONTIGUOUS(matrix);

    torch::Tensor result = reduce_cuda(matrix);
    return result;
}
/*
* PYBIND
*
* Defines the Python module whose name is supplied by the build through
* TORCH_EXTENSION_NAME and registers the C++ function `reduce` under the
* Python name "reduce", with a short docstring.
*/
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("reduce", &reduce, "Reduce matrix with CUDA.");
}
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#include <cstdio>
namespace {

/**
 * Copies `matrix` into `output` element by element.
 *
 * Each thread derives a (row, col) pair from its block/thread indices:
 * the y components index dimension 0 and the x components index
 * dimension 1. A thread whose pair lies inside `output` prints a debug
 * message and then copies every entry along dimension 2 for that pair.
 * Threads outside the bounds do nothing.
 *
 * @param matrix 3-D accessor that is read from.
 * @param output 3-D accessor that is written to.
 */
template <typename scalar_t>
__global__ void reduce_cuda_kernel(
    const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> matrix,
    torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> output)
{
    const int row = blockIdx.y * blockDim.y + threadIdx.y;
    const int col = blockIdx.x * blockDim.x + threadIdx.x;

    // Threads that fall outside the output extent have nothing to do.
    if (row >= output.size(0) || col >= output.size(1)) {
        return;
    }

    printf("Hello from CUDA If.\n");

    const auto depth = matrix.size(2);
    for (int k = 0; k < depth; ++k) {
        output[row][col][k] = matrix[row][col][k];
    }
}

} // namespace
/**
 * Launches reduce_cuda_kernel over a 3-D CUDA tensor and returns the result.
 *
 * The kernel currently copies every element of `matrix` into a new tensor
 * created with torch::zeros_like(matrix).
 *
 * @param matrix 3-D CUDA tensor of a floating type accepted by
 *               AT_DISPATCH_FLOATING_TYPES. The kernel indexes it as
 *               [dim0][dim1][dim2] with one thread per (dim0, dim1) pair.
 * @return The tensor written by the kernel (all zeros if dim0 or dim1 is empty).
 * @throws c10::Error if `matrix` is not a 3-D CUDA tensor, or if the kernel
 *         launch or its execution reports a CUDA error.
 */
torch::Tensor reduce_cuda(torch::Tensor matrix)
{
    TORCH_CHECK(matrix.device().is_cuda(), "matrix must be a CUDA tensor");
    TORCH_CHECK(matrix.dim() == 3,
                "reduce_cuda expects a 3-D tensor, got ", matrix.dim(), " dimension(s)");

    // Make the tensor's device current for the launch. Without this, a tensor
    // living on e.g. cuda:1 would be handed to a kernel launched on the
    // current (default) device.
    const c10::DeviceGuard device_guard(matrix.device());

    // Define output like our input matrix
    auto output = torch::zeros_like(matrix);

    const int64_t H = matrix.size(0);
    const int64_t W = matrix.size(1);
    if (H == 0 || W == 0) {
        // A grid with a zero-sized dimension is an invalid launch configuration.
        return output;
    }

    // Fixed block size of 32x32 threads
    constexpr int kBlockEdge = 32;
    const dim3 threads(kBlockEdge, kBlockEdge, 1);

    // The kernel maps blockIdx.x/threadIdx.x to dimension 1 (W) and
    // blockIdx.y/threadIdx.y to dimension 0 (H), so the grid's x extent must
    // cover W and its y extent must cover H. Integer ceiling division avoids
    // the float round trip of ceilf().
    const dim3 blocks(
        static_cast<unsigned int>((W + kBlockEdge - 1) / kBlockEdge),
        static_cast<unsigned int>((H + kBlockEdge - 1) / kBlockEdge),
        1);

    // Call the kernel
    AT_DISPATCH_FLOATING_TYPES(output.scalar_type(), "reduce_cuda", ([&] {
        reduce_cuda_kernel<scalar_t><<<blocks, threads>>>(
            matrix.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>(),
            output.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>());
    }));

    // Launch-configuration problems are reported right after the launch.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "reduce_cuda kernel launch failed: ", cudaGetErrorString(launch_status));

    // Wait for the kernel so that errors raised while it runs surface here
    // (as an exception rather than a printed status code) and the kernel's
    // printf output is flushed before returning.
    const cudaError_t sync_status = cudaDeviceSynchronize();
    TORCH_CHECK(sync_status == cudaSuccess,
                "reduce_cuda kernel execution failed: ", cudaGetErrorString(sync_status));

    return output;
}
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
# One CUDA extension module named `reduce_cuda`, built from the C++ binding
# file and the CUDA kernel file.
reduce_extension = CUDAExtension(
    name='reduce_cuda',
    sources=['reduce_cuda.cpp', 'reduce_cuda_kernel.cu'],
)

# BuildExtension replaces the stock build_ext command for this package.
setup(
    name='reduce_cuda',
    ext_modules=[reduce_extension],
    cmdclass={'build_ext': BuildExtension},
)
import torch
# Import compiled module
import reduce_cuda
if __name__ == "__main__":
    # Smoke test: push a small random 1x5x2 tensor through the extension and
    # print input and output side by side.
    if not torch.cuda.is_available():
        raise SystemExit("reduce_cuda needs a CUDA device, but none is available.")

    # Keep using the second GPU when there is one (as before); fall back to
    # the first GPU on single-GPU machines instead of failing on "cuda:1".
    device_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f"cuda:{device_index}")

    example = (torch.rand((1, 5, 2)) * 100).to(device)
    reduced = reduce_cuda.reduce(example)
    print("Example:", example)
    print("Reduced:", reduced)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment