Skip to content

Instantly share code, notes, and snippets.

What would you like to do?
import torch
from cupy.cuda import function
from pynvrtc.compiler import Program
from collections import namedtuple
# Demo: reverse each row (last dimension) of a CUDA tensor with a hand-written
# kernel that is compiled at runtime by NVRTC (pynvrtc) and launched through
# CuPy on PyTorch's current CUDA stream.

a = torch.randn(1, 4, 4).cuda()
b = torch.zeros(a.size()).cuda()

# One thread per output element: thread i writes dst[i] from the mirrored
# column of the same row of src, i.e. each row of length w is reversed.
# The kernel indexes raw float pointers, so it assumes both tensors are
# contiguous float32 (the case for the tensors created above under the default
# dtype) -- NOTE(review): confirm before reusing with other tensors.
kernel = '''
extern "C"
__global__ void flip(float *dst, const float *src, int w, int total)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= total)
        return;  // threads past the end of the data must not write

    dst[i] = src[(i / w) * w + (w - (i % w) - 1)];
}
'''
program = Program(kernel, '')
ptx = program.compile()

m = function.Module()
# The compiled PTX has to be loaded into the module before any function can be
# looked up in it.  (presumably compile() returns text, hence the encode -- verify
# against the installed pynvrtc version.)
m.load(bytes(ptx.encode()))
f = m.get_function('flip')

# Minimal stream object exposing only `.ptr`, set to PyTorch's raw stream
# handle, so the kernel is queued on the same stream PyTorch uses.
# NOTE(review): assumes CuPy's launcher only needs the `.ptr` attribute.
Stream = namedtuple('Stream', ['ptr'])
s = Stream(ptr=torch.cuda.current_stream().cuda_stream)

# Size the grid from the element count so tensors larger than one block are
# fully covered; a fixed grid of a single block stops after 1024 elements.
threads_per_block = 1024
n_blocks = (a.numel() + threads_per_block - 1) // threads_per_block
f(grid=(n_blocks, 1, 1), block=(threads_per_block, 1, 1),
  args=[b.data_ptr(), a.data_ptr(), a.size(-1), a.numel()],
  stream=s)

print(a)
print(b)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment