Skip to content

Instantly share code, notes, and snippets.

@psteinb
Created April 25, 2024 12:14
Show Gist options
  • Save psteinb/bc52a4820b1ed743d8dd8c4d24524b7c to your computer and use it in GitHub Desktop.
Save psteinb/bc52a4820b1ed743d8dd8c4d24524b7c to your computer and use it in GitHub Desktop.
Running a cuFFT FFT transform with a callback
import cupy as cp
import cupyx.scipy.fft as cufft
import pytest
import numpy as np
import scipy.fft
# Register cupyx.scipy.fft as the global backend of scipy.fft at import time.
# NOTE(review): the tests visible in this file call cp.fft / cupyx.scipy.fft
# directly and never go through scipy.fft -- confirm this is still needed.
scipy.fft.set_global_backend(cufft)
@pytest.fixture
def callbacks():
    """Provide the transform shape and the CUDA source of a cuFFT store callback.

    The callback source defines ``d_storeCallbackPtr`` as a
    ``cufftCallbackStoreC``. For every output element it reads the entry at
    the same offset from the array passed in as ``callerInfo``, multiplies
    that entry's real part by ``1/DATASIZE`` and its imaginary part by
    ``-1/DATASIZE``, complex-multiplies the FFT output element with the
    result and writes the product to the output buffer.

    Yields:
        tuple: ``(data_shape, backwardCallback)`` where ``data_shape`` is
        ``(256, 256)`` and ``backwardCallback`` is the callback source with
        ``DATASIZE`` set to the number of elements in ``data_shape``.
    """
    data_shape = (256, 256)

    # cuFFT callback routines:
    # https://docs.nvidia.com/cuda/cufft/index.html#cufft-callback-routines
    # The doubled braces below are str.format escapes for literal C braces;
    # only {arrsize} is substituted.
    backwardCallback = r'''
#define DATASIZE {arrsize}
__device__ __host__ inline
cufftComplex ComplexMul(cufftComplex a, cufftComplex b)
{{
cufftComplex c;
c.x = a.x * b.x - a.y * b.y;
c.y = a.x * b.y + a.y * b.x;
return c;
}}
__device__ void FresnelCallback(
void *dataOut,
size_t offset,
cufftComplex element,
void *callerInfo,
void *sharedPtr)
{{
cufftComplex *kernel = (cufftComplex*)callerInfo;
cufftComplex kernelEl = kernel[offset];
kernelEl.x *= 1./DATASIZE;
kernelEl.y *=-1./DATASIZE;
cufftComplex value = ComplexMul(element, kernelEl);
((cufftComplex*)dataOut)[offset] = value;
}}
__device__ cufftCallbackStoreC d_storeCallbackPtr = FresnelCallback;
'''.format(arrsize=data_shape[0] * data_shape[1])

    yield data_shape, backwardCallback
def test_backward_single_callback(callbacks):
    """Run a single forward 2-D FFT with the store callback attached.

    The input and the auxiliary array are complex64 because the callback
    source supplied by the ``callbacks`` fixture is a ``cufftCallbackStoreC``
    that reads and writes ``cufftComplex`` (single-precision) values.
    ``cp.cfloat``, used previously, is the NumPy 1.x alias for complex128
    and is removed in NumPy 2.0.

    Args:
        callbacks: ``(data_shape, backwardCallback)`` from the fixture.
    """
    data_shape, backwardCallback = callbacks
    x = cp.zeros(data_shape, dtype=cp.complex64)
    # Passed to the callback as ``callerInfo``; indexed per output element.
    kernel = cp.ones(data_shape, dtype=cp.complex64)

    with cp.fft.config.set_cufft_callbacks(
        cb_store=backwardCallback,
        cb_store_aux_arr=kernel,
    ):
        # The input is passed through a DLPack round trip before the FFT.
        y = cp.fft.fft2(cp.from_dlpack(x))

    assert y.shape == x.shape
    assert y.dtype == x.dtype
def test_backward_double_vanilla(callbacks):
    """Baseline without callbacks: forward 2-D FFT followed by an inverse one.

    Only the shape from the fixture is used; no callback is installed and no
    auxiliary array is allocated. The data is complex64, the same precision
    as the ``cufftComplex`` callback source provided by the fixture
    (``cp.cfloat`` is the NumPy 1.x alias for complex128 and is removed in
    NumPy 2.0).

    Args:
        callbacks: ``(data_shape, backwardCallback)`` from the fixture.
    """
    data_shape, _ = callbacks
    x = cp.zeros(data_shape, dtype=cp.complex64)

    y = cp.fft.fft2(x)
    y_ = cp.fft.ifft2(y, norm="forward")

    assert y.shape == x.shape
    # The round-trip result was previously computed but never checked.
    assert y_.shape == x.shape
def test_backward_double_callback_cpfft(callbacks):
    """Run a forward and an inverse 2-D FFT via ``cp.fft`` with the callback set.

    Both transforms execute inside the ``set_cufft_callbacks`` context. The
    arrays are complex64 because the callback source from the fixture is a
    ``cufftCallbackStoreC`` operating on ``cufftComplex`` (single-precision)
    values; ``cp.cfloat``, used previously, is the NumPy 1.x alias for
    complex128 and is removed in NumPy 2.0.

    Args:
        callbacks: ``(data_shape, backwardCallback)`` from the fixture.
    """
    data_shape, backwardCallback = callbacks
    x = cp.zeros(data_shape, dtype=cp.complex64)
    # Passed to the callback as ``callerInfo``; indexed per output element.
    kernel = cp.ones(data_shape, dtype=cp.complex64)

    with cp.fft.config.set_cufft_callbacks(
        cb_store=backwardCallback,
        cb_store_aux_arr=kernel,
    ):
        y = cp.fft.fft2(x)
        y_ = cp.fft.ifft2(y, norm="forward")

    assert y.shape == x.shape
    # The inverse result was previously computed but never checked.
    assert y_.shape == x.shape
def test_backward_double_callback_cufft_planahead(callbacks):
    """Run forward and inverse 2-D FFTs via ``cupyx.scipy.fft`` with a prebuilt plan.

    A plan is requested up front with ``value_type="C2C"`` and handed to both
    transforms, which execute inside the ``set_cufft_callbacks`` context. The
    arrays are complex64 because the callback source from the fixture is a
    ``cufftCallbackStoreC`` operating on ``cufftComplex`` (single-precision)
    values; ``cp.cfloat``, used previously, is the NumPy 1.x alias for
    complex128 and is removed in NumPy 2.0.

    Args:
        callbacks: ``(data_shape, backwardCallback)`` from the fixture.
    """
    data_shape, backwardCallback = callbacks
    x = cp.zeros(data_shape, dtype=cp.complex64)
    # Passed to the callback as ``callerInfo``; indexed per output element.
    kernel = cp.ones(data_shape, dtype=cp.complex64)

    # NOTE(review): the plan is created before the callback context is
    # entered -- confirm whether CuPy attaches the callbacks to an explicitly
    # supplied plan or only to plans it creates inside the context.
    plan = cufft.get_fft_plan(x, shape=data_shape, value_type="C2C")

    with cp.fft.config.set_cufft_callbacks(
        cb_store=backwardCallback,
        cb_store_aux_arr=kernel,
    ):
        y = cufft.fft2(x, plan=plan)
        y_ = cufft.ifft2(y, norm="forward", plan=plan)

    # Check both results; the forward result used to be overwritten by the
    # inverse one before any assertion ran.
    assert y.shape == x.shape
    assert y_.shape == x.shape
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment