Skip to content

Instantly share code, notes, and snippets.

@clemisch
Last active July 16, 2019 14:31
Show Gist options
  • Save clemisch/9ec752430b2eb65d9840d69fc8c058bb to your computer and use it in GitHub Desktop.
Save clemisch/9ec752430b2eb65d9840d69fc8c058bb to your computer and use it in GitHub Desktop.
RTK GPU transfer
import numpy as np
import itk
from itk import RTK as rtk
# Image types used by the helpers below, both instantiated with itk.F pixels
# and dimension 3. GPU_IMG is the CUDA-backed image class.
GPU_IMG = rtk.CudaImage[itk.F, 3]
# NOTE(review): CPU_IMG is not referenced by the code visible in this file;
# presumably kept for callers that need the matching CPU image type.
CPU_IMG = rtk.Image[itk.F, 3]
def cpu_to_gpu_image(cpu_img, gpu_img=None):
    """Attach the pixel container and metadata of ``cpu_img`` to a GPU image.

    Parameters
    ----------
    cpu_img : ITK image
        Source image whose pixel container, image information, buffered
        region and requested region are handed to the GPU image.
    gpu_img : GPU image, optional
        Image to configure. When omitted, a fresh ``GPU_IMG`` is created.

    Returns
    -------
    The configured GPU image (``gpu_img`` itself when one was passed in).
    """
    target = GPU_IMG.New() if gpu_img is None else gpu_img

    # The GPU image receives the very same pixel container object as the CPU
    # image. NOTE(review): presumably this means no pixel data is copied
    # here -- confirm against the ITK/RTK CudaImage documentation.
    container = cpu_img.GetPixelContainer()
    target.SetPixelContainer(container)

    # Image information first, then the regions, in the same order as before.
    target.CopyInformation(cpu_img)
    buffered = cpu_img.GetBufferedRegion()
    target.SetBufferedRegion(buffered)
    requested = cpu_img.GetRequestedRegion()
    target.SetRequestedRegion(requested)
    return target
def _to_itk_order(values, *appended):
    """Return ``values`` reversed (C order -> ITK order) as a list of floats.

    Any ``appended`` values are added at the end. Building a fresh list means
    lists, tuples and NumPy arrays are all handled the same way.
    """
    return [float(v) for v in reversed(values)] + [float(v) for v in appended]


def forwardproject(vol, geometry, proj_shape, proj_origin, proj_spacing, img_spacing, img_origin):
    """Forward project a volume with RTK's CUDA forward projector.

    Parameters
    ----------
    vol : 3d array
        Volume to be forward projected. Must be float32 and C-contiguous.
    geometry : rtk geometry object
    proj_shape : 3 element tuple/list
        proj_shape = (num_angles, proj_shape_x, proj_shape_y)
    proj_origin : 2 element tuple/list
        proj_origin = (proj_origin_x, proj_origin_y)
    proj_spacing : 2 element tuple/list
        proj_spacing = (proj_spacing_x, proj_spacing_y)
    img_spacing : 3 element tuple/list
    img_origin : 3 element tuple/list

    Returns
    -------
    out : float32 array of shape ``proj_shape``
        The array that backs the projection image handed to the filter.

    Raises
    ------
    AssertionError
        If ``vol`` is not float32 or not C-contiguous.

    All arguments in C order indexing. Will be converted to ITK indexing inside function.
    """
    # Raised explicitly (rather than via ``assert``) so the checks are not
    # stripped under ``python -O``; the exception type is unchanged.
    if vol.dtype != np.float32:
        raise AssertionError("vol must have dtype float32, got %s" % vol.dtype)
    if not vol.flags.c_contiguous:
        raise AssertionError("vol must be C-contiguous")

    out = np.zeros(proj_shape, dtype=np.float32)
    proj_img = itk.GetImageViewFromArray(out)
    # Reversed for ITK indexing; trailing 0 / 1 follow the RTK convention for
    # the third (projection) axis.
    proj_img.SetOrigin(_to_itk_order(proj_origin, 0))
    proj_img.SetSpacing(_to_itk_order(proj_spacing, 1))

    vol_img = itk.GetImageViewFromArray(vol)
    vol_img.SetOrigin(_to_itk_order(img_origin))
    vol_img.SetSpacing(_to_itk_order(img_spacing))

    # Wrap the CPU images as GPU images.
    proj_img_gpu = cpu_to_gpu_image(proj_img)
    vol_img_gpu = cpu_to_gpu_image(vol_img)
    man = proj_img_gpu.GetCudaDataManager()

    fwd = rtk.CudaForwardProjectionImageFilter[GPU_IMG].New()
    fwd.SetGeometry(geometry)
    fwd.SetInput(0, proj_img_gpu)
    fwd.SetInput(1, vol_img_gpu)
    fwd.Update()

    # Copy data back to the CPU buffer that ``out`` views.
    # NOTE(review): this reads the result through the *input* projection
    # image's data manager, so it presumably relies on the filter writing its
    # output in place -- confirm against the RTK documentation.
    man.SetCPUBufferPointer(proj_img.GetBufferPointer())
    man.UpdateCPUBuffer()
    return out
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment