Skip to content

Instantly share code, notes, and snippets.

@pieper
Created May 21, 2022 20:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pieper/60bf90c6b1323c3eb3fd07d5e50746d0 to your computer and use it in GitHub Desktop.
Save pieper/60bf90c6b1323c3eb3fd07d5e50746d0 to your computer and use it in GitHub Desktop.
Slicer with python WebGPU compute shader
"""
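Install wgpu as described here: https://github.com/pygfx/wgpu-py
Tested with Slicer 5.0.2 and wgpu commit 537c3eab68e9eef77681fc5545532380df26d8cc (essentially release 0.8.1).
To run, save this file as slicer-compute.py and enter the following in Slicer's Python console:
    exec(open("./slicer-compute.py").read())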
"""
import numpy
import wgpu
import wgpu.backends.rs # Select backend
from wgpu.utils import compute_with_buffers # Convenience function
# Negate every voxel of the MRHead sample volume on the GPU with a WebGPU
# compute shader, check the result against numpy, and write it back into the
# volume. `slicer` is not imported here; it is expected to be provided by the
# environment the script is exec'd in.

# Reuse MRHead if it is already loaded in the scene, otherwise download it.
try:
    mrHead = slicer.util.getNode("MRHead")
except slicer.util.MRMLNodeNotFoundException:
    import SampleData
    mrHead = SampleData.SampleDataLogic().downloadMRHead()

# headArray is written in place further down, and the volume node is then
# notified through arrayFromVolumeModified.
headArray = slicer.util.arrayFromVolume(mrHead)
sliceSize = headArray.shape[1] * headArray.shape[2]
# The shader declares its buffers as array<i32>, so upload a 32-bit copy.
headIntArray = headArray.astype('int32')
# Total voxel count; `.size` avoids the full copy that `flatten()` would make.
bufferSize = headArray.size

# One invocation per voxel: with n=headArray.shape the three components of
# global_invocation_id range over the three array axes, and `i` is the
# corresponding row-major flat index into the buffers.
shader = """
@group(0) @binding(0)
var<storage,read> data1: array<i32>;
@group(0) @binding(1)
var<storage,read_write> data2: array<i32>;
@stage(compute)
@workgroup_size(1)
fn main(@builtin(global_invocation_id) index: vec3<u32>) {
let i: u32 = index.x * @@SLICE_SIZE@@ + index.y * @@ROW_SIZE@@ + index.z;
data2[i] = -1 * data1[i];
}
"""
# Bake the volume dimensions into the shader source as u32 literals.
shader = shader.replace("@@SLICE_SIZE@@", str(sliceSize)+"u")
shader = shader.replace("@@ROW_SIZE@@", str(headArray.shape[2])+"u")

print("computing...")
out = compute_with_buffers(
    input_arrays={0: headIntArray},
    output_arrays={1: (bufferSize, "i")},
    shader=shader,
    n=headArray.shape,
)
print("done")

# `out` is a dict matching the output types
# Select data from buffer at binding 1
resultArray = numpy.array(out[1])
resultVolume = resultArray.reshape(headArray.shape)

# Compare voxel by voxel rather than by mean: a mean comparison would still
# pass if the shader wrote the right values to the wrong positions.
assert numpy.array_equal(resultVolume, -headIntArray), \
    "GPU result does not equal -1 * input volume"

# Cast back to the volume's own scalar type before writing in place.
headArray[:] = resultVolume.astype(headArray.dtype)
slicer.util.arrayFromVolumeModified(mrHead)

# Toggle auto window/level off and on again; presumably this forces the
# display range to be recomputed for the negated intensities -- TODO confirm.
mrHead.GetDisplayNode().SetAutoWindowLevel(False)
mrHead.GetDisplayNode().SetAutoWindowLevel(True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment