Skip to content

Instantly share code, notes, and snippets.

@maweigert
Created December 2, 2019 20:23
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save maweigert/33aa0b104cc4e4b4e79aaa6d3c270a4c to your computer and use it in GitHub Desktop.
Save maweigert/33aa0b104cc4e4b4e79aaa6d3c270a4c to your computer and use it in GitHub Desktop.
"""
RL-Deconv test with pyopencl
Input data (http://bigwww.epfl.ch/deconvolution/bars/)
Bars-G10-P15-stack.tif
PSF-Bars-stack.tif
"""
import numpy as np
from tifffile import imread
from gputools import OCLArray, OCLProgram, fft, fft_convolve, fft_plan, OCLElementwiseKernel
from time import time
import argparse
# Elementwise GPU kernel computing b = a / b in place, where a is a real
# (float) buffer and b is a complex (cfloat_t) buffer. The result overwrites
# the second argument, hence the "_second" suffix in the Python name.
# NOTE(review): the trailing string is presumably the kernel's name --
# confirm against the OCLElementwiseKernel signature.
_divide_rc_inplace_second = OCLElementwiseKernel(
"float *a, cfloat_t * b",
"b[i] = cfloat_rdivide(a[i],b[i])",
"divide_rc_inplace")
# Elementwise GPU kernel that copies a real (float) buffer a into a complex
# (cfloat_t) buffer b, converting each element with cfloat_fromreal.
# Both buffers are indexed with the same i, so they are expected to hold the
# same number of elements.
# NOTE(review): the trailing string is presumably the kernel's name --
# confirm against the OCLElementwiseKernel signature.
_copy_float_as_complex = OCLElementwiseKernel(
"float *a, cfloat_t * b",
"b[i] = cfloat_fromreal(a[i])",
"copy_float_as_complex")
def time_with_queue_finished():
    """Return the current wall time after the device queue has finished.

    Used as a timestamp for benchmarking GPU code: the queue of the current
    gputools device is finished first, then ``time()`` is read.
    """
    # gputools is imported lazily, only when a timestamp is requested
    from gputools import get_device

    device = get_device()
    device.queue.finish()
    return time()
def _origin_centered_kernels(h):
    """Return ``(kernel, adjoint)`` for a centred PSF, both with origin at index 0.

    ``h`` is assumed to be centred at index ``n // 2`` along every axis (the
    usual layout of a PSF stack). ``kernel`` is the PSF shifted so that its
    centre sits at index 0, as needed for FFT-based circular convolution.
    ``adjoint`` is the kernel of the adjoint (correlation) operator, i.e.
    ``kernel[-i mod n]`` along every axis.
    """
    # ifftshift moves index n//2 to index 0.  fftshift does the opposite and
    # only coincides with it on even-length axes; on odd-length axes it
    # would leave the PSF displaced by one sample.
    kernel = np.ascontiguousarray(np.fft.ifftshift(h))
    # A plain flip maps index i to n-1-i; rolling by one sample afterwards
    # yields -i mod n, which keeps the origin at index 0 for any axis length.
    all_axes = tuple(range(kernel.ndim))
    adjoint = np.roll(np.flip(kernel), 1, axis=all_axes)
    return kernel, np.ascontiguousarray(adjoint)


def rl_deconv(data, h, niter=10):
    """Richardson-Lucy deconvolution of a 3D volume ``data`` with PSF ``h``.

    Parameters
    ----------
    data : np.ndarray of dtype float32
        Measured volume; it is also used as the initial estimate.
    h : np.ndarray of dtype float32
        Point spread function with the same shape as ``data``, centred at
        index ``n // 2`` along every axis.
    niter : int
        Number of RL iterations.

    Returns
    -------
    np.ndarray
        Real part of the final estimate, copied back to the host.

    Raises
    ------
    ValueError
        If either input is not float32 or the shapes differ.

    Timing information is printed to stdout.
    """
    # validate first so that bad input fails before the device is touched
    if not (data.dtype.type == np.float32 and h.dtype.type == np.float32):
        raise ValueError("input data has to be of type np.float32")
    if data.shape != h.shape:
        raise ValueError("data and h have to be same shape")

    t_all = time_with_queue_finished()

    kernel, kernel_adjoint = _origin_centered_kernels(h)

    # transfer CPU -> GPU and create temporary buffers
    # has to use complex type as there is no rfftn in reikna yet
    data_g = OCLArray.from_array(data)
    h0_g = OCLArray.from_array(kernel)
    h0flip_g = OCLArray.from_array(kernel_adjoint)
    u_g = OCLArray.empty(data.shape, np.complex64)
    tmp_g = OCLArray.empty(data_g.shape, np.complex64)
    h_g = OCLArray.empty(h.shape, np.complex64)
    hflip_g = OCLArray.empty(h.shape, np.complex64)

    # first transferring real and then copy to complex on GPU is faster
    _copy_float_as_complex(data_g, u_g)
    _copy_float_as_complex(h0_g, h_g)
    _copy_float_as_complex(h0flip_g, hflip_g)

    # one FFT plan is shared by every transform below
    plan = fft_plan(data.shape)

    # transform both kernels once; the loop reuses their spectra
    fft(h_g, inplace=True, plan=plan)
    fft(hflip_g, inplace=True, plan=plan)

    t_loop = time_with_queue_finished()

    # RL steps: u <- u * H^T( data / (H u) )
    for _ in range(niter):
        # tmp = H u
        fft_convolve(u_g, h_g, res_g=tmp_g,
                     plan=plan, kernel_is_fft=True)
        # tmp = data / tmp
        _divide_rc_inplace_second(data_g, tmp_g)
        # tmp = H^T tmp
        fft_convolve(tmp_g, hflip_g, inplace=True,
                     plan=plan, kernel_is_fft=True)
        u_g *= tmp_g

    t_end = time_with_queue_finished()

    print(f"time for {niter} RL-iterations:")
    print(f" all: {t_end-t_all:.3f} s")
    print(f" w/o transfer: {t_end-t_loop:.3f} s")
    return u_g.real.get()
if __name__ == '__main__':
    # command line: iteration count and an optional result plot
    cli = argparse.ArgumentParser(description='')
    cli.add_argument('-n', '--niter', type=int, default=100)
    cli.add_argument('--plot', action="store_true")
    args = cli.parse_args()

    # read the measured stack and the PSF as float32; scale the PSF to unit sum
    measured = imread("Bars-G10-P15-stack.tif").astype(np.float32)
    psf = imread("PSF-Bars-stack.tif").astype(np.float32)
    psf /= np.sum(psf)

    # run the deconvolution
    restored = rl_deconv(measured, psf, niter=args.niter)

    if args.plot:
        # matplotlib is only needed when a plot was requested
        import matplotlib.pyplot as plt

        # rows: maximum projection along axis 0 and axis 2
        # columns: measured input, deconvolved result
        _, panels = plt.subplots(2, 2)
        volumes = (measured, restored)
        for panel_row, axis in zip(panels, (0, 2)):
            for panel, volume in zip(panel_row, volumes):
                panel.imshow(np.max(volume, axis))
        plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment