Created
June 24, 2024 18:22
-
-
Save somecho/0ce97a911dad2e6141daa0003730ea08 to your computer and use it in GitHub Desktop.
WGPU Compute Shader Particles in Nannou
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Layout of each storage buffer bound to this shader: a single
// runtime-sized array of 32-bit unsigned values.
struct Buffer {
    data: array<u32>,
}
// Group 0, binding 0: first read/write storage buffer.
@group(0) @binding(0) var<storage, read_write> ping: Buffer;

// Group 0, binding 1: second read/write storage buffer.
@group(0) @binding(1) var<storage, read_write> pong: Buffer;

// Group 0, binding 2: frame counter supplied by the host as a uniform.
@group(0) @binding(2) var<uniform> frame_num: u32;
// Compute entry point. Each invocation handles the cell at `id.x`: it reads
// the cell from one buffer, toggles it between 0 and 1, and stores the result
// in the other buffer. Which buffer is the source swaps every 60 frames.
@compute @workgroup_size(1, 1, 1)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let index = id.x;

    // The dispatch size is chosen by the host, so it may exceed the buffer
    // length. Skip out-of-range invocations instead of relying on the
    // implementation-defined handling of out-of-bounds array accesses.
    let cell_count = min(arrayLength(&ping.data), arrayLength(&pong.data));
    if (index >= cell_count) {
        return;
    }

    // True during even 60-frame windows: read `ping`, write `pong`.
    let write_to_pong = (frame_num / 60u) % 2u == 0u;
    if (write_to_pong) {
        pong.data[index] = (ping.data[index] + 1u) % 2u;
    } else {
        ping.data[index] = (pong.data[index] + 1u) % 2u;
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use nannou::prelude::*; | |
/// Particle count along X; also the X extent of the compute dispatch.
const NUM_X: usize = 500;
/// Particle count along Y; also the Y extent of the compute dispatch.
const NUM_Y: usize = 2;
/// Total number of particles simulated and drawn.
const NUM_PARTICLES: usize = NUM_X * NUM_Y;
struct Render { | |
pipeline: wgpu::RenderPipeline, | |
vertex_buffer: wgpu::Buffer, | |
} | |
struct Compute { | |
pipeline: wgpu::ComputePipeline, | |
position_buffer: wgpu::Buffer, | |
} | |
struct Model { | |
render: Render, | |
compute: Compute, | |
bind_group: wgpu::BindGroup, | |
buffer_size: wgpu::BufferAddress, | |
uniform_bind_group: wgpu::BindGroup, | |
} | |
/// One particle vertex as laid out in the vertex buffer.
///
/// `#[repr(C)]` pins the layout to two consecutive `f32`s (8 bytes), matching
/// a `Float32x2` vertex attribute, so `size_of::<Vertex>()` can be used as the
/// vertex buffer stride.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq)]
struct Vertex {
    /// X and Y coordinates of the particle.
    position: [f32; 2],
}
fn main() { | |
nannou::app(model).update(update).view(view).run(); | |
} | |
fn model(app: &App) -> Model { | |
let w_id = app.new_window().size(720, 1024).build().unwrap(); | |
let window = app.window(w_id).unwrap(); | |
let device = window.device(); | |
// shader modules | |
let cs_mod = device.create_shader_module(wgpu::include_wgsl!("shaders/cs.wgsl")); | |
let render_mod = device.create_shader_module(wgpu::include_wgsl!("shaders/render.wgsl")); | |
// create position and velocity buffer | |
let usage = | |
wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::COPY_SRC; | |
let win_dims: Vec2 = window.rect().wh(); | |
let pos_data = create_position_data(win_dims); | |
let vel_data = create_velocity_data(); | |
let pos_buffer = device.create_buffer_init(&wgpu::BufferInitDescriptor { | |
label: Some("Position Storage buffer"), | |
contents: unsafe { wgpu::bytes::from_slice(&pos_data) }, | |
usage, | |
}); | |
let vel_buffer = device.create_buffer_init(&wgpu::BufferInitDescriptor { | |
label: Some("Velocity Storage Buffer"), | |
contents: unsafe { wgpu::bytes::from_slice(&vel_data) }, | |
usage, | |
}); | |
// create static uniform buffer | |
let u_res: [f32; 2] = [window.rect().w(), window.rect().h()]; | |
let u_res_buffer = device.create_buffer_init(&wgpu::BufferInitDescriptor { | |
label: Some("Static uniforms buffer"), | |
contents: unsafe { wgpu::bytes::from_slice(&[u_res]) }, | |
usage: wgpu::BufferUsages::UNIFORM, | |
}); | |
// create bind group from data | |
let visibility = wgpu::ShaderStages::COMPUTE; | |
let bind_group_layout = wgpu::BindGroupLayoutBuilder::new() | |
.storage_buffer(visibility, false, false) | |
.storage_buffer(visibility, false, false) | |
.build(device); | |
let buffer_size = (std::mem::size_of::<f32>() * NUM_PARTICLES * 2) as wgpu::BufferAddress; | |
let bind_group = wgpu::BindGroupBuilder::new() | |
.buffer_bytes(&pos_buffer, 0, wgpu::BufferSize::new(buffer_size)) | |
.buffer_bytes(&vel_buffer, 0, wgpu::BufferSize::new(buffer_size)) | |
.build(device, &bind_group_layout); | |
// create uniform bind group | |
let uniform_bind_group_layout = wgpu::BindGroupLayoutBuilder::new() | |
.uniform_buffer( | |
wgpu::ShaderStages::COMPUTE | wgpu::ShaderStages::VERTEX, | |
false, | |
) | |
.build(device); | |
let uniform_bind_group = wgpu::BindGroupBuilder::new() | |
.buffer_bytes( | |
&u_res_buffer, | |
0, | |
wgpu::BufferSize::new(std::mem::size_of_val(&u_res) as u64), | |
) | |
.build(device, &uniform_bind_group_layout); | |
// create pipeline layout | |
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { | |
label: Some("Common Pipeline Layout"), | |
bind_group_layouts: &[&bind_group_layout, &uniform_bind_group_layout], | |
push_constant_ranges: &[], | |
}); | |
let render_pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { | |
label: Some("Common Pipeline Layout"), | |
bind_group_layouts: &[&uniform_bind_group_layout], | |
push_constant_ranges: &[], | |
}); | |
// create compute pipeline | |
let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { | |
label: Some("Compute Pipeline"), | |
layout: Some(&pipeline_layout), | |
module: &cs_mod, | |
entry_point: "main", | |
}); | |
// create render pipeline | |
let vertex_buffer = device.create_buffer_init(&wgpu::BufferInitDescriptor { | |
label: Some("Vertex Data"), | |
contents: unsafe { wgpu::bytes::from_slice(&[pos_data]) }, | |
usage: wgpu::BufferUsages::VERTEX | wgpu::BufferUsages::COPY_DST, | |
}); | |
let render_pipeline = | |
wgpu::RenderPipelineBuilder::from_layout(&render_pipeline_layout, &render_mod) | |
.vertex_entry_point("vertexMain") | |
.fragment_shader(&render_mod) | |
.fragment_entry_point("fragmentMain") | |
.sample_count(window.msaa_samples()) | |
.color_format(Frame::TEXTURE_FORMAT) | |
.primitive_topology(wgpu::PrimitiveTopology::PointList) | |
.add_vertex_buffer::<f32>(&wgpu::vertex_attr_array![0 => Float32x2]) | |
.build(device); | |
Model { | |
render: Render { | |
pipeline: render_pipeline, | |
vertex_buffer, | |
}, | |
bind_group, | |
compute: Compute { | |
pipeline: compute_pipeline, | |
position_buffer: pos_buffer, | |
}, | |
buffer_size, | |
uniform_bind_group, | |
} | |
} | |
fn update(app: &App, model: &mut Model, _update: Update) { | |
let window = app.main_window(); | |
let device = window.device(); | |
let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { | |
label: Some("Compute Encoder"), | |
}); | |
{ | |
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { | |
label: Some("Position Compute"), | |
}); | |
pass.set_pipeline(&model.compute.pipeline); | |
pass.set_bind_group(0, &model.bind_group, &[]); | |
pass.set_bind_group(1, &model.uniform_bind_group, &[]); | |
pass.dispatch_workgroups(NUM_X as u32, NUM_Y as u32, 1); | |
} | |
encoder.copy_buffer_to_buffer( | |
&model.compute.position_buffer, | |
0, | |
&model.render.vertex_buffer, | |
0, | |
model.buffer_size, | |
); | |
window.queue().submit(Some(encoder.finish())); | |
} | |
fn view(app: &App, model: &Model, frame: Frame) { | |
{ | |
let mut encoder = frame.command_encoder(); | |
let mut pass = wgpu::RenderPassBuilder::new() | |
.color_attachment(frame.texture_view(), |c| c) | |
.begin(&mut encoder); | |
pass.set_pipeline(&model.render.pipeline); | |
pass.set_bind_group(0, &model.uniform_bind_group, &[]); | |
pass.set_vertex_buffer(0, model.render.vertex_buffer.slice(..)); | |
pass.draw(0..NUM_PARTICLES as u32, 0..1); | |
} | |
let draw = app.draw(); | |
let win = app.main_window().rect(); | |
draw.text(app.fps().floor().to_string().as_str()) | |
.x_y(win.left() + 40.0, win.top() - 40.0); | |
draw.to_frame(app, &frame).unwrap(); | |
} | |
fn create_position_data(bounds: Vec2) -> [f32; NUM_PARTICLES * 2] { | |
let mut position_data: [f32; NUM_PARTICLES * 2] = [0.0; NUM_PARTICLES * 2]; | |
for i in 0..NUM_PARTICLES { | |
position_data[i * 2] = random_f32() * bounds.x; | |
position_data[i * 2 + 1] = random_f32() * bounds.y; | |
} | |
position_data | |
} | |
fn create_velocity_data() -> [f32; NUM_PARTICLES * 2] { | |
let mut velocity_data: [f32; NUM_PARTICLES * 2] = [0.0; NUM_PARTICLES * 2]; | |
for i in 0..NUM_PARTICLES { | |
velocity_data[i * 2] = random_f32() * 2.0 - 1.0; | |
velocity_data[i * 2 + 1] = random_f32() * 2.0 - 1.0; | |
} | |
velocity_data | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Group 0, binding 0: first cell buffer, read-only in this shader.
@group(0) @binding(0) var<storage, read> ping: array<u32>;

// Group 0, binding 1: second cell buffer, read-only in this shader.
@group(0) @binding(1) var<storage, read> pong: array<u32>;

// Group 0, binding 2: frame counter supplied by the host as a uniform.
@group(0) @binding(2) var<uniform> frame_num: u32;
// Values passed from the vertex stage to the rasterizer and fragment stage.
struct VertexOutput {
    // Clip-space position of the vertex.
    @builtin(position) pos: vec4f,
    // Instance index forwarded to the fragment stage. WGSL requires integer
    // inter-stage values to be flat-interpolated, so this is stated
    // explicitly rather than left to an implementation default.
    @location(0) @interpolate(flat) id: u32,
};
// Vertex entry point: scales the incoming position by 1/20, shifts it along X
// by a tenth of the instance index, and forwards the instance index as `id`.
@vertex
fn vertexMain(@location(0) pos: vec2f, @builtin(instance_index) instance: u32) -> VertexOutput {
    let instance_shift = vec2f(f32(instance) / 10.0, 0.0);
    let clip_xy = pos / 20.0 + instance_shift;
    return VertexOutput(vec4f(clip_xy, 0.0, 1.0), instance);
}
// Fragment entry point: outputs the cell value for `id` in the red channel.
// The value comes from `pong` during even 60-frame windows and from `ping`
// otherwise.
//
// `id` is an integer inter-stage input, which WGSL requires to be
// flat-interpolated; the attribute is stated explicitly rather than left to
// an implementation default.
@fragment
fn fragmentMain(@location(0) @interpolate(flat) id: u32) -> @location(0) vec4<f32> {
    // `select(f, t, cond)` returns `t` when `cond` is true.
    let read_pong = (frame_num / 60u) % 2u == 0u;
    let on = select(ping[id], pong[id], read_pong);
    return vec4<f32>(f32(on), 0.0, 0.0, 1.0);
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The only dependency is nannou 0.19.0.