-
-
Save maweigert/33aa0b104cc4e4b4e79aaa6d3c270a4c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
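Richardson-Lucy (RL) deconvolution test using pyopencl (via gputools).

Input data (from http://bigwww.epfl.ch/deconvolution/bars/):
    Bars-G10-P15-stack.tif  -- image stack to be deconvolved
    PSF-Bars-stack.tif      -- corresponding point spread function (PSF)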
""" | |
import numpy as np | |
from tifffile import imread | |
from gputools import OCLArray, OCLProgram, fft, fft_convolve, fft_plan, OCLElementwiseKernel | |
from time import time | |
import argparse | |
# Elementwise OpenCL kernel: b = a / b, computed in place on the second
# argument (hence "_inplace_second"), where a is real (float) and b is
# complex (cfloat_t).
# rl_deconv uses it to form the ratio data / (current estimate convolved with
# the PSF).
# NOTE(review): there is no epsilon guard in the kernel, so a zero entry in b
# presumably yields inf/nan -- confirm whether inputs can reach zero.
_divide_rc_inplace_second = OCLElementwiseKernel(
    "float *a, cfloat_t * b",
    "b[i] = cfloat_rdivide(a[i],b[i])",
    "divide_rc_inplace")

# Elementwise OpenCL kernel: copy the real (float) array a into the complex
# (cfloat_t) array b, i.e. real -> complex on the GPU.
# This lets callers upload only real data and widen it on the device.
_copy_float_as_complex = OCLElementwiseKernel(
    "float *a, cfloat_t * b",
    "b[i] = cfloat_fromreal(a[i])",
    "copy_float_as_complex")
def time_with_queue_finished():
    """Return the wall-clock time after all queued GPU work has completed.

    Blocks on the command queue of the device returned by gputools'
    ``get_device()`` before reading the clock. OpenCL calls are asynchronous,
    so this wait is what makes the timestamp usable for benchmarking GPU code.
    """
    from gputools import get_device

    device_queue = get_device().queue
    device_queue.finish()  # wait until every enqueued command has finished
    return time()
def rl_deconv(data, h, niter=10):
    """Richardson-Lucy deconvolution of ``data`` with the PSF ``h`` on the GPU.

    Starting from ``data`` itself as the initial estimate ``u``, each
    iteration performs

        u <- u * ((data / (u (*) h)) (*) h_flip)

    Here ``(*)`` is an FFT-based (hence presumably circular) convolution and
    ``h_flip(x) = h(-x)`` is the mirrored PSF.

    The PSF is assumed to be centered in its array. Its center is moved to
    the origin (index 0 along every axis) before it is Fourier transformed.

    Timing information is printed to stdout:
      - "all" includes the host->GPU upload and the PSF preparation.
      - "w/o transfer" covers only the iterations.

    Parameters
    ----------
    data : np.ndarray
        Observed image/volume, dtype float32.
    h : np.ndarray
        Point spread function, dtype float32, same shape as ``data``,
        centered in the array.
    niter : int
        Number of RL iterations (default 10).

    Returns
    -------
    np.ndarray
        Real part of the final estimate, same shape as ``data``.

    Raises
    ------
    ValueError
        If ``data`` or ``h`` is not float32, or if their shapes differ.
    """
    t_all = time_with_queue_finished()

    if not (data.dtype.type == np.float32 and h.dtype.type == np.float32):
        raise ValueError("input data has to be of type np.float32")
    if data.shape != h.shape:
        raise ValueError("data and h have to be same shape")

    # Move the PSF center to the origin. ifftshift is the exact inverse of
    # fftshift. Using fftshift here would leave odd-sized PSFs off by one
    # sample.
    h0 = np.fft.ifftshift(h)
    # Mirrored PSF with circular indexing: h0flip[x] = h0[-x mod N].
    # A plain flip would move the origin to index N-1; rolling by one puts it
    # back at 0. This keeps the correction step the exact adjoint of the
    # forward blur, for any array size and any number of dimensions.
    h0flip = np.roll(np.flip(h0), 1, axis=tuple(range(h0.ndim)))

    # Transfer CPU -> GPU and create temporary buffers.
    # Complex buffers are used because only a complex FFT is available
    # (no rfftn in reikna yet).
    # Host arrays must be contiguous for the upload.
    data_g = OCLArray.from_array(np.ascontiguousarray(data))
    h0_g = OCLArray.from_array(np.ascontiguousarray(h0))
    h0flip_g = OCLArray.from_array(np.ascontiguousarray(h0flip))
    u_g = OCLArray.empty(data.shape, np.complex64)
    tmp_g = OCLArray.empty(data.shape, np.complex64)
    h_g = OCLArray.empty(h.shape, np.complex64)
    hflip_g = OCLArray.empty(h.shape, np.complex64)

    # Transferring real data and widening it to complex on the GPU is faster
    # than uploading complex arrays.
    _copy_float_as_complex(data_g, u_g)  # initial estimate u = data
    _copy_float_as_complex(h0_g, h_g)
    _copy_float_as_complex(h0flip_g, hflip_g)

    # One FFT plan is shared by all transforms (all arrays have the same shape).
    plan = fft_plan(data.shape)
    # Transform the PSFs once, so the loop can pass kernel_is_fft=True.
    fft(h_g, inplace=True, plan=plan)
    fft(hflip_g, inplace=True, plan=plan)

    t_loop = time_with_queue_finished()

    for _ in range(niter):
        # tmp = u (*) h
        fft_convolve(u_g, h_g, res_g=tmp_g,
                     plan=plan, kernel_is_fft=True)
        # tmp = data / tmp
        # NOTE(review): there is no guard against zeros in tmp.
        _divide_rc_inplace_second(data_g, tmp_g)
        # tmp = tmp (*) h_flip
        fft_convolve(tmp_g, hflip_g, inplace=True,
                     plan=plan, kernel_is_fft=True)
        # Multiplicative RL update.
        u_g *= tmp_g

    t_end = time_with_queue_finished()
    print(f"time for {niter} RL-iterations:")
    print(f" all: {t_end-t_all:.3f} s")
    print(f" w/o transfer: {t_end-t_loop:.3f} s")
    return u_g.real.get()
if __name__ == '__main__':
    # Command-line driver: load a data stack and a PSF, run RL deconvolution,
    # and optionally show max projections of the input and the result.
    parser = argparse.ArgumentParser(
        description='Richardson-Lucy deconvolution of a 3D stack on the GPU '
                    '(gputools/pyopencl).')
    parser.add_argument('-n', '--niter', type=int, default=100,
                        help='number of RL iterations (default: %(default)s)')
    parser.add_argument('--plot', action="store_true",
                        help='show max projections of input and result')
    # The defaults are the file names the script previously hard-coded.
    parser.add_argument('--data', default="Bars-G10-P15-stack.tif",
                        help='TIFF stack to deconvolve (default: %(default)s)')
    parser.add_argument('--psf', default="PSF-Bars-stack.tif",
                        help='TIFF stack with the PSF (default: %(default)s)')
    args = parser.parse_args()

    # Load the data and normalize the PSF to unit sum.
    y = imread(args.data).astype(np.float32)
    h = imread(args.psf).astype(np.float32)
    psf_sum = np.sum(h)
    if not psf_sum > 0:
        # Dividing by a zero, negative or NaN sum would silently poison the
        # result with nan/inf.
        parser.error(f"PSF in {args.psf!r} must have a positive sum")
    h /= psf_sum

    # RL deconvolution.
    u = rl_deconv(y, h, niter=args.niter)

    if args.plot:
        # Plot max projections: the rows project along axes 0 and 2, and the
        # columns show the input and the deconvolved result.
        import matplotlib.pyplot as plt
        _, axs = plt.subplots(2, 2)
        for row_axes, dim in zip(axs, (0, 2)):
            for ax, img, label in zip(row_axes, (y, u), ("input", "deconvolved")):
                ax.imshow(np.max(img, dim))
                ax.set_title(f"{label} (max along axis {dim})")
        plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment