@apriori
Last active August 8, 2021 19:34
Example of host-code generation for a rust-gpu kernel, driven only by parsing the compiled SPIR-V module.
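The point of the example: all of the host-side code further down is emitted from the compiled SPIR-V alone; descriptor sets, bindings and buffer kinds are recovered from the module's decorations rather than from the Rust shader source. The gist does not include the generator itself, so the following is only a sketch of that reflection step, assuming the rspirv crate is used for parsing (the function name and output format are illustrative):

use rspirv::dr::Operand;
use rspirv::spirv;

/// Print every DescriptorSet/Binding decoration found in a SPIR-V module.
/// Sketch only; the real generator in the gist is not shown.
fn dump_bindings(spirv_bytes: &[u8]) {
    let module = rspirv::dr::load_bytes(spirv_bytes).expect("invalid SPIR-V");
    for inst in &module.annotations {
        // OpDecorate <target-id> <decoration> <literal operands...>
        if inst.class.opcode != spirv::Op::Decorate {
            continue;
        }
        if let [Operand::IdRef(target), Operand::Decoration(dec), value @ ..] =
            inst.operands.as_slice()
        {
            match dec {
                spirv::Decoration::DescriptorSet | spirv::Decoration::Binding => {
                    println!("%{} {:?} = {:?}", target, dec, value);
                }
                _ => {}
            }
        }
    }
}

Dumping the DescriptorSet and Binding decorations like this is enough to recover the set/binding layout that the generated PhaseDiffImpl::new further down reproduces as wgpu bind group layouts.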
// `UVec3` and `Vec2` are the glam vector types used by rust-gpu shaders;
// `ImageParams` is a user-defined uniform struct whose definition is not
// included in the gist. The body is left empty here because only the
// interface matters for code generation.
#[spirv(compute(threads(64)))]
pub fn phase_diff(
    #[spirv(global_invocation_id)] id: UVec3,
    #[spirv(uniform, descriptor_set = 0, binding = 0)] image_params: &ImageParams,
    #[spirv(storage_buffer, descriptor_set = 1, binding = 0)] i0: &[u8],
    #[spirv(storage_buffer, descriptor_set = 1, binding = 1)] i1: &[u8],
    #[spirv(storage_buffer, descriptor_set = 1, binding = 2)] i2: &[u8],
    #[spirv(storage_buffer, descriptor_set = 1, binding = 3)] i3: &[u8],
    #[spirv(storage_buffer, descriptor_set = 2, binding = 0)] phase: &mut [Vec2],
) {
}
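From this entry point alone, compiled to SPIR-V and reflected, the generator emits the host-side wrapper shown below: a PhaseDiffImpl struct owning one wgpu buffer per binding plus the compute pipeline, and an async phase_diff method that uploads the inputs, dispatches the kernel, and returns a handle to the output buffer. Note that the generator never sees the Rust definition of ImageParams; it only sees the SPIR-V struct type behind set 0 / binding 0, which is why the generated code refers to it as self::types::Struct2. A hypothetical placeholder, purely for orientation (the real type and its fields are not part of the gist):

// Hypothetical stand-in for the uniform struct bound at set = 0, binding = 0.
// Field names and layout are invented for illustration; the real definition
// is not included in the gist.
#[repr(C)]
#[derive(Clone, Copy)]
pub struct ImageParams {
    pub width: u32,
    pub height: u32,
}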
pub struct PhaseDiffImpl {
    __kernel_exec_size: usize,
    image_params: wgpu_utils::TypedBuffer<self::types::Struct2>,
    i0: wgpu_utils::TypedBuffer<u8>,
    i1: wgpu_utils::TypedBuffer<u8>,
    i2: wgpu_utils::TypedBuffer<u8>,
    i3: wgpu_utils::TypedBuffer<u8>,
    phase: wgpu_utils::TypedBuffer<glam::Vec2>,
    pipeline: wgpu::ComputePipeline,
    bindgroups: std::collections::BTreeMap<u32, wgpu::BindGroup>,
}
pub struct PhaseDiffOutput<'a> {
    pub phase: &'a wgpu_utils::TypedBuffer<glam::Vec2>,
}
impl PhaseDiffImpl {
    pub fn new(
        device: &wgpu::Device,
        __kernel_exec_size: usize,
        i0_size: usize,
        i1_size: usize,
        i2_size: usize,
        i3_size: usize,
        phase_size: usize,
    ) -> Self {
        let shader_binary = wgpu_utils::load_spirv_shader(env!("pmt.spv"));
        let module = device.create_shader_module(&shader_binary);
        use std::collections::BTreeMap;
        // Collect per-descriptor-set maps of binding index -> value.
        fn insert_value<V>(
            descriptor_set: u32,
            binding: u32,
            value: V,
            map: &mut BTreeMap<u32, BTreeMap<u32, V>>,
        ) {
            map.entry(descriptor_set)
                .or_insert_with(BTreeMap::new)
                .insert(binding, value);
        }
        let mut buffers: BTreeMap<u32, BTreeMap<u32, &wgpu::Buffer>> = BTreeMap::new();
        let mut bindgroup_layout_entries: BTreeMap<u32, BTreeMap<u32, wgpu::BindGroupLayoutEntry>> =
            BTreeMap::new();
        let mut bindgroup_entries: BTreeMap<u32, BTreeMap<u32, wgpu::BindGroupEntry>> =
            BTreeMap::new();
        // Set 0, binding 0: the ImageParams uniform.
        let image_params_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("bla"),
            size: (std::mem::size_of::<self::types::Struct2>() * 1usize) as u64,
            usage: wgpu::BufferUsage::UNIFORM | wgpu::BufferUsage::COPY_DST,
            mapped_at_creation: false,
        });
        let image_params_bindgroup_entry = wgpu::BindGroupEntry {
            binding: 0u32,
            resource: image_params_buffer.as_entire_binding(),
        };
        insert_value(0u32, 0u32, &image_params_buffer, &mut buffers);
        insert_value(
            0u32,
            0u32,
            wgpu::BindGroupLayoutEntry {
                binding: 0u32,
                visibility: wgpu::ShaderStage::COMPUTE,
                ty: wgpu::BindingType::Buffer {
                    // image_params is declared `uniform` in the shader and the
                    // buffer is created with UNIFORM usage, so the layout entry
                    // must be Uniform as well.
                    ty: wgpu::BufferBindingType::Uniform,
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
            &mut bindgroup_layout_entries,
        );
        insert_value(
            0u32,
            0u32,
            image_params_bindgroup_entry,
            &mut bindgroup_entries,
        );
        let i0_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("bla"),
            size: (std::mem::size_of::<u8>() * i0_size) as u64,
            usage: wgpu::BufferUsage::STORAGE | wgpu::BufferUsage::COPY_DST,
            mapped_at_creation: false,
        });
        let i0_bindgroup_entry = wgpu::BindGroupEntry {
            binding: 0u32,
            resource: i0_buffer.as_entire_binding(),
        };
        insert_value(1u32, 0u32, &i0_buffer, &mut buffers);
        insert_value(
            1u32,
            0u32,
            wgpu::BindGroupLayoutEntry {
                binding: 0u32,
                visibility: wgpu::ShaderStage::COMPUTE,
                ty: wgpu::BindingType::Buffer {
                    ty: wgpu::BufferBindingType::Storage { read_only: true },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
            &mut bindgroup_layout_entries,
        );
        insert_value(1u32, 0u32, i0_bindgroup_entry, &mut bindgroup_entries);
        let i1_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("bla"),
            size: (std::mem::size_of::<u8>() * i1_size) as u64,
            usage: wgpu::BufferUsage::STORAGE | wgpu::BufferUsage::COPY_DST,
            mapped_at_creation: false,
        });
        let i1_bindgroup_entry = wgpu::BindGroupEntry {
            binding: 1u32,
            resource: i1_buffer.as_entire_binding(),
        };
        insert_value(1u32, 1u32, &i1_buffer, &mut buffers);
        insert_value(
            1u32,
            1u32,
            wgpu::BindGroupLayoutEntry {
                binding: 1u32,
                visibility: wgpu::ShaderStage::COMPUTE,
                ty: wgpu::BindingType::Buffer {
                    ty: wgpu::BufferBindingType::Storage { read_only: true },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
            &mut bindgroup_layout_entries,
        );
        insert_value(1u32, 1u32, i1_bindgroup_entry, &mut bindgroup_entries);
        let i2_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("bla"),
            size: (std::mem::size_of::<u8>() * i2_size) as u64,
            usage: wgpu::BufferUsage::STORAGE | wgpu::BufferUsage::COPY_DST,
            mapped_at_creation: false,
        });
        let i2_bindgroup_entry = wgpu::BindGroupEntry {
            binding: 2u32,
            resource: i2_buffer.as_entire_binding(),
        };
        insert_value(1u32, 2u32, &i2_buffer, &mut buffers);
        insert_value(
            1u32,
            2u32,
            wgpu::BindGroupLayoutEntry {
                binding: 2u32,
                visibility: wgpu::ShaderStage::COMPUTE,
                ty: wgpu::BindingType::Buffer {
                    ty: wgpu::BufferBindingType::Storage { read_only: true },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
            &mut bindgroup_layout_entries,
        );
        insert_value(1u32, 2u32, i2_bindgroup_entry, &mut bindgroup_entries);
        let i3_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("bla"),
            size: (std::mem::size_of::<u8>() * i3_size) as u64,
            usage: wgpu::BufferUsage::STORAGE | wgpu::BufferUsage::COPY_DST,
            mapped_at_creation: false,
        });
        let i3_bindgroup_entry = wgpu::BindGroupEntry {
            binding: 3u32,
            resource: i3_buffer.as_entire_binding(),
        };
        insert_value(1u32, 3u32, &i3_buffer, &mut buffers);
        insert_value(
            1u32,
            3u32,
            wgpu::BindGroupLayoutEntry {
                binding: 3u32,
                visibility: wgpu::ShaderStage::COMPUTE,
                ty: wgpu::BindingType::Buffer {
                    ty: wgpu::BufferBindingType::Storage { read_only: true },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
            &mut bindgroup_layout_entries,
        );
        insert_value(1u32, 3u32, i3_bindgroup_entry, &mut bindgroup_entries);
        let phase_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("bla"),
            size: (std::mem::size_of::<glam::Vec2>() * phase_size) as u64,
            usage: wgpu::BufferUsage::STORAGE | wgpu::BufferUsage::COPY_DST,
            mapped_at_creation: false,
        });
        let phase_bindgroup_entry = wgpu::BindGroupEntry {
            binding: 0u32,
            resource: phase_buffer.as_entire_binding(),
        };
        insert_value(2u32, 0u32, &phase_buffer, &mut buffers);
        insert_value(
            2u32,
            0u32,
            wgpu::BindGroupLayoutEntry {
                binding: 0u32,
                visibility: wgpu::ShaderStage::COMPUTE,
                ty: wgpu::BindingType::Buffer {
                    // phase is written by the shader (`&mut [Vec2]`), so it binds
                    // as a writable storage buffer, matching the STORAGE usage above.
                    ty: wgpu::BufferBindingType::Storage { read_only: false },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
            &mut bindgroup_layout_entries,
        );
        insert_value(2u32, 0u32, phase_bindgroup_entry, &mut bindgroup_entries);
        // Build one bind group layout and one bind group per descriptor set.
        let mut layouts = BTreeMap::new();
        let mut bindgroups = BTreeMap::new();
        for descriptor_set in buffers.keys() {
            let set_bindgroup_entries = bindgroup_entries.get(descriptor_set).unwrap().values();
            let set_layout_entries = bindgroup_layout_entries
                .get(descriptor_set)
                .unwrap()
                .values();
            let layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: Some(&format!("{} Set {}", "phase_diff", descriptor_set)),
                entries: set_layout_entries.cloned().collect::<Vec<_>>().as_slice(),
            });
            let group = device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some(&format!("{} Set {}", "phase_diff", descriptor_set)),
                layout: &layout,
                entries: set_bindgroup_entries.cloned().collect::<Vec<_>>().as_slice(),
            });
            bindgroups.insert(*descriptor_set, group);
            layouts.insert(*descriptor_set, layout);
        }
        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("pipelinelayout"),
            bind_group_layouts: layouts.values().collect::<Vec<_>>().as_slice(),
            push_constant_ranges: &[],
        });
        log::info!("Create pipeline");
        let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("compute pipeline"),
            layout: Some(&pipeline_layout),
            module: &module,
            entry_point: "phase_diff",
        });
        log::info!("Create pipeline done");
        use wgpu_utils::BufferExt;
        Self {
            __kernel_exec_size,
            image_params: image_params_buffer.as_typed(1usize),
            i0: i0_buffer.as_typed(i0_size),
            i1: i1_buffer.as_typed(i1_size),
            i2: i2_buffer.as_typed(i2_size),
            i3: i3_buffer.as_typed(i3_size),
            phase: phase_buffer.as_typed(phase_size),
            pipeline,
            bindgroups,
        }
    }
    pub async fn phase_diff<'a>(
        &'a self,
        device: &wgpu::Device,
        queue: &wgpu::Queue,
        image_params: self::types::Struct2,
        i0: wgpu_utils::TypedBuffer<u8>,
        i1: wgpu_utils::TypedBuffer<u8>,
        i2: wgpu_utils::TypedBuffer<u8>,
        i3: wgpu_utils::TypedBuffer<u8>,
    ) -> PhaseDiffOutput<'a> {
        // Upload the uniform data and copy the inputs into the kernel-owned buffers.
        self.image_params
            .copy_single_from_host(queue, &image_params);
        i0.copy_to(device, queue, &self.i0);
        i1.copy_to(device, queue, &self.i1);
        i2.copy_to(device, queue, &self.i2);
        i3.copy_to(device, queue, &self.i3);
        let mut encoder =
            device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
        {
            let mut cpass =
                encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: None });
            cpass.set_pipeline(&self.pipeline);
            for (k, v) in &self.bindgroups {
                cpass.set_bind_group(*k, v, &[]);
            }
            // Integer division: this assumes __kernel_exec_size is a multiple of
            // the 64-thread workgroup size declared on the shader entry point.
            cpass.dispatch(self.__kernel_exec_size as u32 / 64u32, 1, 1);
        }
        queue.submit(Some(encoder.finish()));
        PhaseDiffOutput { phase: &self.phase }
    }
}
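A hypothetical end-to-end use of the generated wrapper, based on the signatures above. wgpu_utils is the author's helper crate and is not part of the gist, so the TypedBuffer::from_slice constructor used for the input buffers below is an assumed helper, not a confirmed API; everything else follows the new and phase_diff signatures shown.

// Sketch only: assumes a wgpu Device/Queue are already set up, that the
// generated `types` module is in scope, and that wgpu_utils provides some way
// to create a TypedBuffer from host data (spelled `from_slice` here purely
// for illustration).
async fn run_phase_diff(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    params: types::Struct2,
    frames: [&[u8]; 4],
) {
    // Assume one shader invocation per pixel and one output Vec2 per pixel.
    let pixels = frames[0].len();
    let kernel = PhaseDiffImpl::new(
        device,
        pixels,                         // __kernel_exec_size
        pixels, pixels, pixels, pixels, // i0..i3 sizes in elements
        pixels,                         // phase output size in Vec2 elements
    );
    let i0 = wgpu_utils::TypedBuffer::from_slice(device, frames[0]); // assumed API
    let i1 = wgpu_utils::TypedBuffer::from_slice(device, frames[1]); // assumed API
    let i2 = wgpu_utils::TypedBuffer::from_slice(device, frames[2]); // assumed API
    let i3 = wgpu_utils::TypedBuffer::from_slice(device, frames[3]); // assumed API
    let output = kernel.phase_diff(device, queue, params, i0, i1, i2, i3).await;
    // output.phase is a &TypedBuffer<glam::Vec2> living in GPU memory; reading
    // it back to the host would go through a staging buffer and wgpu's map_async.
    let _ = output.phase;
}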