Skip to content

Instantly share code, notes, and snippets.

@Smerity
Forked from szagoruyko/cupy-pytorch-ptx.py
Created May 21, 2017 23:21
Show Gist options
  • Star 7 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Smerity/1e023085a49c66827128aa4e72e22bf1 to your computer and use it in GitHub Desktop.
CuPy example for PyTorch updated to support Python 3
"""Flip a CUDA tensor along its last dimension using a hand-written CUDA kernel.

Interop demo: PyTorch allocates the tensors, NVRTC (via pynvrtc) compiles the
kernel source to PTX at runtime, and CuPy loads the PTX and launches the
kernel on PyTorch's current CUDA stream. Requires a CUDA-capable GPU.
"""
import torch
from cupy.cuda import function
from pynvrtc.compiler import Program
from collections import namedtuple

a = torch.randn(1, 4, 4).cuda()   # input tensor on the GPU
b = torch.zeros(a.size()).cuda()  # output tensor, same shape, zero-initialized

# One thread per element: thread i mirrors its column index within its row
# (row = i / w, column = i % w, mirrored column = w - (i % w) - 1).
# The `i >= total` guard makes over-provisioned launches safe.
kernel = '''
extern "C"
__global__ void flip(float *dst, const float *src, int w, int total)
{
int i = blockIdx.x * blockDim.x + threadIdx.x;
if(i >= total)
return;
dst[i] = src[(i / w) * w + (w - (i % w) - 1)];
}
'''

# NVRTC's Python 3 bindings expect bytes for both the source and the
# (nominal) file name, hence the .encode() calls.
program = Program(kernel.encode(), 'flip.cu'.encode())
ptx = program.compile()

m = function.Module()
m.load(ptx.encode())  # PTX text -> loaded CUDA module (.encode() already yields bytes)
f = m.get_function('flip')

# Launch on PyTorch's current stream so the kernel is ordered correctly
# with respect to PyTorch's own CUDA operations.
Stream = namedtuple('Stream', ['ptr'])
s = Stream(ptr=torch.cuda.current_stream().cuda_stream)

total = a.numel()
threads = 1024
# BUG FIX: the original launched with grid=(1,1,1), which covers at most
# 1024 elements and silently leaves the rest of a larger tensor unflipped.
# Size the grid so every element gets a thread (ceiling division).
blocks = (total + threads - 1) // threads
f(grid=(blocks, 1, 1), block=(threads, 1, 1),
  args=[b.data_ptr(), a.data_ptr(), a.size(-1), total],
  stream=s)

print(a)
print(b)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment