
@szagoruyko
Created January 9, 2017 21:01
import pycuda.autoinit
import pycuda.driver as drv
import numpy as np
import torch

# Touching torch.cuda before the first kernel launch forces torch to set up its
# CUDA context; without this, the pycuda call below fails (see the comments).
x = torch.cuda.FloatTensor(8)

from pycuda.compiler import SourceModule

mod = SourceModule("""
__global__ void multiply_them(float *dest, float *a, float *b)
{
    const int i = threadIdx.x;
    dest[i] = a[i] * b[i];
}
""")
multiply_them = mod.get_function("multiply_them")


class Holder(drv.PointerHolderBase):
    """Wraps a torch CUDA tensor so pycuda can use its device pointer."""
    def __init__(self, t):
        super(Holder, self).__init__()
        self.t = t
        self.gpudata = t.data_ptr()

    def get_pointer(self):
        return self.t.data_ptr()


a = np.random.randn(400).astype(np.float32)
b = np.random.randn(400).astype(np.float32)
a = torch.from_numpy(a).cuda()
b = torch.from_numpy(b).cuda()
dest = torch.Tensor(a.size()).cuda()

multiply_them(
    Holder(dest),
    Holder(a),
    Holder(b),
    block=(400, 1, 1), grid=(1, 1))

torch.cuda.synchronize()
print(dest - a * b)  # all zeros: the kernel wrote into the torch-owned buffer
@themightyoarfish commented Jul 20, 2018

Do you have a suggestion for how to use the pointer to create a GPUArray, and then go from that back to a tensor?
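
Not a full answer, but a rough sketch of the direction I would try. It assumes that GPUArray accepts a raw device pointer through its gpudata argument and that the tensor is contiguous float32; I have not verified this against every pycuda version, and going back to a tensor is done here with a plain copy through host memory.

import numpy as np
import torch
import pycuda.autoinit
import pycuda.gpuarray as gpuarray

x = torch.cuda.FloatTensor(8)  # same context-initialization trick as in the gist

t = torch.randn(400).cuda()    # contiguous float32 tensor on the GPU

# View the tensor's memory as a GPUArray without copying. torch still owns the
# memory, so the tensor must stay alive for as long as the GPUArray is used.
t_gpu = gpuarray.GPUArray(tuple(t.shape), dtype=np.float32,
                          gpudata=t.data_ptr())

# Back to a tensor: the simple (copying) route goes device -> host -> device.
t_back = torch.from_numpy(t_gpu.get()).cuda()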

@WhatAShot

In my view, pycuda cannot get access to the pointer at the start, so you add the line "x = torch.cuda.FloatTensor(8)" to make it work. Is there a more graceful solution?

@Emerald01

This sample code solved my bug. In particular, I find that if I have a torch tensor and push it to the GPU, for example data = data.cuda(), then the pycuda function called later throws this error:

func._set_block_shape(*block)
pycuda._driver.LogicError: cuFuncSetBlockShape failed: invalid resource handle

I had no way to solve it. I don't think it is really about pointer initialization, as hinted by @WhatAShot, because my function fails whether or not it takes that tensor as input; the problem appears as soon as data.cuda() is called before a pycuda function call.

Then I found this code, and the magic line x = torch.cuda.FloatTensor(8) makes the problem disappear. It feels like torch.cuda has some strange behavior that conflicts with pycuda in some way. Any comment?

@timothylimyl

@Emerald01

Same issue here. I wish there were an elegant way to convert a GPUArray (pycuda) to a pytorch tensor and vice versa.

Has anyone managed to get it to work with 2D blocks?
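
For what it's worth, here is a minimal 2D sketch that reuses the imports and the Holder class from the gist above. It assumes a contiguous, row-major float32 tensor; the kernel name add_2d and the sizes are just placeholders, and the grid is computed so that every element gets one thread.

mod2d = SourceModule("""
__global__ void add_2d(float *dest, float *a, float *b, int rows, int cols)
{
    const int row = blockIdx.y * blockDim.y + threadIdx.y;
    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row < rows && col < cols) {
        const int i = row * cols + col;   // row-major, contiguous layout
        dest[i] = a[i] + b[i];
    }
}
""")
add_2d = mod2d.get_function("add_2d")

rows, cols = 64, 128
a = torch.randn(rows, cols).cuda()
b = torch.randn(rows, cols).cuda()
dest = torch.empty_like(a)

block = (16, 16, 1)
grid = ((cols + block[0] - 1) // block[0],   # blocks along x cover the columns
        (rows + block[1] - 1) // block[1])   # blocks along y cover the rows

add_2d(Holder(dest), Holder(a), Holder(b),
       np.int32(rows), np.int32(cols),
       block=block, grid=grid)
torch.cuda.synchronize()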

@benjamindkilleen

This single gist has been a lifesaver; I've been using pycuda and torch together for a while now without understanding the cryptic CUDA bugs that resulted. I would only add that I believe torch.cuda.init() can be called instead of the "magic line" x = torch.cuda.FloatTensor(8). As far as I can tell, that line simply causes torch to initialize its CUDA context.
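
In other words, the top of the gist would become something like the following; torch.cuda.init() is documented as a no-op when the CUDA state is already initialized, so calling it unconditionally should be safe.

import pycuda.autoinit
import pycuda.driver as drv
import numpy as np
import torch

torch.cuda.init()   # replaces x = torch.cuda.FloatTensor(8): explicitly set up torch's CUDA context

from pycuda.compiler import SourceModule
# ... the rest of the gist stays unchanged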

@tjyuyao commented Jan 8, 2022

@benjamindkilleen @Emerald01 @WhatAShot

Thanks for pointing out the core issue and for the insightful discussion. I found that using import pycuda.autoprimaryctx instead of import pycuda.autoinit does the trick; I got there with the help of these links: issue 285 and the pycuda.autoprimaryctx docs. From the documentation:

The module pycuda.autoprimaryctx is similar to pycuda.autoinit, except that it retains the device primary context instead of creating a new context in pycuda.tools.make_default_context().

A modified version of this gist can be found as a fork.
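
Concretely, the change is just the first import; the rest of the gist stays the same, and the "magic line" should then be unnecessary because torch and pycuda share the device's primary context. This is a sketch that assumes a pycuda version recent enough to ship pycuda.autoprimaryctx.

import pycuda.autoprimaryctx        # instead of: import pycuda.autoinit
import pycuda.driver as drv
import numpy as np
import torch

from pycuda.compiler import SourceModule
# ... the rest of the gist stays unchanged; no x = torch.cuda.FloatTensor(8) needed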

@tjyuyao commented Jan 9, 2022

@timothylimyl

I have written an easy-to-use helper class for multi-dimensional pytorch tensor access here.

@RoyAmoyal commented Nov 15, 2023

What happens if I want to integrate this into a neural network? I want to use it in the forward pass, with some conv2d layers applied before and after. How can I do that, and how do I determine the block size, etc.? @tjyuyao

@tjyuyao commented Nov 16, 2023

@RoyAmoyal You can use pytorch's autograd.Function API and implement the forward and backward passes as separate pycuda functions.
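
A skeletal, untested sketch of that idea, reusing Holder and multiply_them from the gist: for the elementwise product a * b the gradients are grad_a = grad_out * b and grad_b = grad_out * a, so the same kernel can serve both passes. It assumes contiguous float32 CUDA tensors with at most 1024 elements, so a single block suffices, as in the 400-element example above.

import torch

class PycudaMultiply(torch.autograd.Function):
    """Elementwise a * b computed by the pycuda kernel from the gist."""

    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        out = torch.empty_like(a)
        multiply_them(Holder(out), Holder(a), Holder(b),
                      block=(a.numel(), 1, 1), grid=(1, 1))
        torch.cuda.synchronize()
        return out

    @staticmethod
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        g = grad_out.contiguous()
        grad_a = torch.empty_like(a)
        grad_b = torch.empty_like(b)
        # d(a*b)/da = b and d(a*b)/db = a, so the same kernel computes both gradients.
        multiply_them(Holder(grad_a), Holder(g), Holder(b),
                      block=(a.numel(), 1, 1), grid=(1, 1))
        multiply_them(Holder(grad_b), Holder(g), Holder(a),
                      block=(a.numel(), 1, 1), grid=(1, 1))
        torch.cuda.synchronize()
        return grad_a, grad_b

# Inside a module's forward(), between the conv2d layers:
#   y = PycudaMultiply.apply(a, b)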
