# cuda_basics.py
# Gist by @bigsnarfdude, created May 20, 2024.
import torch, os, math, gzip, pickle
import matplotlib.pyplot as plt
from urllib.request import urlretrieve
from pathlib import Path
from torch import tensor
import torchvision as tv
import torchvision.transforms.functional as tvf
from torchvision import io
from torch.utils.cpp_extension import load_inline

url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/1600px-Cute_dog.jpg?20140729055059'
path_img = Path('puppy.jpg')
if not path_img.exists(): urlretrieve(url, path_img)
img = io.read_image('puppy.jpg')
print(img.shape)
img[:2,:3,:4]
img2 = tvf.resize(img, 150, antialias=True)
ch,h,w = img2.shape
ch,h,w,h*w
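# Optional peek (assumes a notebook/display environment): show the resized
# image, permuting from CHW to HWC for matplotlib.
plt.imshow(img2.permute(1,2,0))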
def rgb2grey_py(x):
    # Pure-Python luminance conversion: one loop iteration per output pixel,
    # reading the R, G, and B planes of the flattened CHW tensor.
    c,h,w = x.shape
    n = h*w
    x = x.flatten()
    res = torch.empty(n, dtype=x.dtype, device=x.device)
    for i in range(n): res[i] = 0.2989*x[i] + 0.5870*x[i+n] + 0.1140*x[i+2*n]
    return res.view(h,w)
img_g = rgb2grey_py(img2)
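# Sanity check: broadcasting the same luminance weights should reproduce the
# loop version up to uint8 truncation, hence the atol=1 tolerance.
img_g_vec = (tensor([0.2989, 0.5870, 0.1140])[:,None,None] * img2).sum(0)
torch.isclose(img_g.float(), img_g_vec, atol=1).all()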
def run_kernel(f, times, *args):
    # Simulate launching `times` CUDA threads by calling f once per index.
    for i in range(times): f(i, *args)

def rgb2grey_k(i, x, out, n):
    out[i] = 0.2989*x[i] + 0.5870*x[i+n] + 0.1140*x[i+2*n]

def rgb2grey_pyk(x):
    c,h,w = x.shape
    n = h*w
    x = x.flatten()
    res = torch.empty(n, dtype=x.dtype, device=x.device)
    run_kernel(rgb2grey_k, h*w, x, res, n)
    return res.view(h,w)
img_g = rgb2grey_pyk(img2)
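# The simulated-kernel version performs the same arithmetic in the same order,
# so it should match the plain Python loop exactly.
torch.equal(img_g, rgb2grey_py(img2))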
def blk_kernel(f, blocks, threads, *args):
    # Simulate a 1D CUDA grid: every (block, thread) pair runs the kernel once.
    for i in range(blocks):
        for j in range(threads): f(i, j, threads, *args)

def rgb2grey_bk(blockidx, threadidx, blockdim, x, out, n):
    i = blockidx*blockdim + threadidx
    # Guard: the last block may run past the end of the data.
    if i<n: out[i] = 0.2989*x[i] + 0.5870*x[i+n] + 0.1140*x[i+2*n]

def rgb2grey_pybk(x):
    c,h,w = x.shape
    n = h*w
    x = x.flatten()
    res = torch.empty(n, dtype=x.dtype, device=x.device)
    threads = 256
    blocks = int(math.ceil(h*w/threads))
    blk_kernel(rgb2grey_bk, blocks, threads, x, res, n)
    return res.view(h,w)
img_g = rgb2grey_pybk(img2)
(img_g == rgb2grey_pyk(img2)).all()
os.environ['CUDA_LAUNCH_BLOCKING']='1'  # make launches synchronous so errors surface at the launch site
%load_ext wurlitzer  # forward C-level stdout (e.g. printf from kernels) into the notebook
def load_cuda(cuda_src, cpp_src, funcs, opt=False, verbose=False):
    return load_inline(cuda_sources=[cuda_src], cpp_sources=[cpp_src], functions=funcs,
                       extra_cuda_cflags=["-O2"] if opt else [], verbose=verbose, name="inline_ext")
cuda_begin = r'''
#include <torch/extension.h>
#include <stdio.h>
#include <c10/cuda/CUDAException.h>
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a + b - 1) / b;}
'''
cuda_src = cuda_begin + r'''
__global__ void rgb_to_grayscale_kernel(unsigned char* x, unsigned char* out, int n) {
    int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i<n) out[i] = 0.2989*x[i] + 0.5870*x[i+n] + 0.1140*x[i+2*n];
}

torch::Tensor rgb_to_grayscale(torch::Tensor input) {
    CHECK_INPUT(input);
    int h = input.size(1);
    int w = input.size(2);
    printf("h*w: %d*%d\n", h, w);
    auto output = torch::empty({h,w}, input.options());
    int threads = 256;
    rgb_to_grayscale_kernel<<<cdiv(w*h,threads), threads>>>(
        input.data_ptr<unsigned char>(), output.data_ptr<unsigned char>(), w*h);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}'''
cpp_src = "torch::Tensor rgb_to_grayscale(torch::Tensor input);"
module = load_cuda(cuda_src, cpp_src, ['rgb_to_grayscale'], verbose=True)
imgc = img.contiguous().cuda()
res = module.rgb_to_grayscale(imgc).cpu()
h,w = res.shape
h,w,h*w
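# Optional peek (assumes a display environment): view the GPU-computed
# grayscale result.
plt.imshow(res, cmap='gray')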
MNIST_URL='https://github.com/mnielsen/neural-networks-and-deep-learning/blob/master/data/mnist.pkl.gz?raw=true'
path_data = Path('data')
path_data.mkdir(exist_ok=True)
path_gz = path_data/'mnist.pkl.gz'
if not path_gz.exists(): urlretrieve(MNIST_URL, path_gz)
with gzip.open(path_gz, 'rb') as f: ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')
x_train,y_train,x_valid,y_valid = map(tensor, (x_train,y_train,x_valid,y_valid))
x_train.shape,x_train.type()
imgs = x_train.reshape((-1,28,28))
imgs.shape
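# Optional peek (assumes a display environment): view the first MNIST digit.
plt.imshow(imgs[0], cmap='gray')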
torch.manual_seed(1)
weights = torch.randn(784,10)
weights
m1 = x_valid[:5]
m2 = weights
m1.shape,m2.shape
ar,ac = m1.shape # n_rows * n_cols
br,bc = m2.shape
(ar,ac),(br,bc)
t1 = torch.zeros(ar, bc)
t1.shape
for i in range(ar):         # 5
    for j in range(bc):     # 10
        for k in range(ac): # 784
            t1[i,j] += m1[i,k] * m2[k,j]
t1.shape
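# The triple loop should agree with PyTorch's built-in matmul, up to
# floating-point accumulation order.
torch.isclose(t1, m1@m2, atol=1e-5).all()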
import numpy as np
np.set_printoptions(precision=2, linewidth=140)
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
def matmul(a,b):
    (ar,ac),(br,bc) = a.shape,b.shape
    c = torch.zeros(ar, bc)
    for i in range(ar):
        for j in range(bc):
            for k in range(ac): c[i,j] += a[i,k] * b[k,j]
    return c

_ = matmul(m1, m2)
ar*bc*ac  # total multiply-accumulate operations in the triple loop
from types import SimpleNamespace as ns
def blk_kernel2d(f, blocks, threads, *args):
    # Simulate a 2D CUDA grid; x indexes columns (the inner i1/j1 loops),
    # y indexes rows, matching CUDA's blockIdx/threadIdx convention.
    for i0 in range(blocks.y):
        for i1 in range(blocks.x):
            for j0 in range(threads.y):
                for j1 in range(threads.x): f(ns(x=i1,y=i0), ns(x=j1,y=j0), threads, *args)

def matmul_bk(blockidx, threadidx, blockdim, m, n, out, h, w, k):
    r = blockidx.y*blockdim.y + threadidx.y
    c = blockidx.x*blockdim.x + threadidx.x
    if (r>=h or c>=w): return
    o = 0.
    for i in range(k): o += m[r*k+i] * n[i*w+c]
    out[r*w+c] = o
def matmul_2d(m, n):
    h,k = m.shape
    k2,w = n.shape
    assert k==k2, "Size mismatch!"
    output = torch.zeros(h, w, dtype=m.dtype)
    tpb = ns(x=16,y=16)
    blocks = ns(x=math.ceil(w/tpb.x), y=math.ceil(h/tpb.y))
    blk_kernel2d(matmul_bk, blocks, tpb,
                 m.flatten(), n.flatten(), output.flatten(), h, w, k)
    return output
res = matmul_2d(m1, m2)
torch.isclose(t1, res).all()
def matmul(a,b):
    # Broadcast version: one Python loop over rows; the j and k loops are
    # fused into a single elementwise multiply plus sum.
    (ar,ac),(br,bc) = a.shape,b.shape
    c = torch.zeros(ar, bc)
    for i in range(ar): c[i] = (a[i,:,None] * b).sum(dim=0)
    return c

torch.isclose(t1, matmul(m1, m2)).all()
_ = matmul(m1, m2)
m1 = x_train
tr = matmul(m1, m2)
tr.shape
_ = matmul(m1, m2)
%time _ = matmul(m1, m2)
ar,ac = m1.shape
br,bc = m2.shape
ar*bc*ac  # multiply-accumulates for the full training-set matmul
cuda_src = cuda_begin + r'''
__global__ void rgb_to_grayscale_kernel(unsigned char* x, unsigned char* out, int w, int h) {
    int c = blockIdx.x*blockDim.x + threadIdx.x;
    int r = blockIdx.y*blockDim.y + threadIdx.y;
    if (c<w && r<h) {
        int i = r*w + c;
        int n = h*w;
        out[i] = 0.2989*x[i] + 0.5870*x[i+n] + 0.1140*x[i+2*n];
    }
}

torch::Tensor rgb_to_grayscale(torch::Tensor input) {
    CHECK_INPUT(input);
    int h = input.size(1);
    int w = input.size(2);
    torch::Tensor output = torch::empty({h,w}, input.options());
    dim3 tpb(16,16);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    rgb_to_grayscale_kernel<<<blocks, tpb>>>(
        input.data_ptr<unsigned char>(), output.data_ptr<unsigned char>(), w, h);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}'''
module = load_cuda(cuda_src, cpp_src, ['rgb_to_grayscale'])
res = module.rgb_to_grayscale(imgc).cpu()
cuda_src = cuda_begin + r'''
__global__ void matmul_k(float* m, float* n, float* out, int h, int w, int k) {
    int r = blockIdx.y*blockDim.y + threadIdx.y;
    int c = blockIdx.x*blockDim.x + threadIdx.x;
    if (r>=h || c>=w) return;
    float o = 0;
    for (int i = 0; i<k; ++i) o += m[r*k+i] * n[i*w+c];
    out[r*w+c] = o;
}

torch::Tensor matmul(torch::Tensor m, torch::Tensor n) {
    CHECK_INPUT(m); CHECK_INPUT(n);
    int h = m.size(0);
    int w = n.size(1);
    int k = m.size(1);
    TORCH_CHECK(k==n.size(0), "Size mismatch!");
    auto output = torch::zeros({h, w}, m.options());
    dim3 tpb(16,16);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    matmul_k<<<blocks, tpb>>>(
        m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''
cpp_src = "torch::Tensor matmul(torch::Tensor m, torch::Tensor n);"
module = load_cuda(cuda_src, cpp_src, ['matmul'])
m1c,m2c = m1.contiguous().cuda(), m2.contiguous().cuda()
torch.isclose(tr, module.matmul(m1c, m2c).cpu(), atol=1e-5).all()
%time res = module.matmul(m1c, m2c).cpu()
res.shape
torch.isclose(tr, (m1c@m2c).cpu(), atol=1e-5).all()
%timeit -n 10 _=(m1c@m2c).cpu()
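# For comparison, time the custom CUDA kernel the same way as the built-in
# matmul above (both include the device-to-host copy).
%timeit -n 10 _=module.matmul(m1c, m2c).cpu()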