Skip to content

Instantly share code, notes, and snippets.

@somecho
Created June 24, 2024 18:22
Show Gist options
  • Save somecho/0ce97a911dad2e6141daa0003730ea08 to your computer and use it in GitHub Desktop.
Save somecho/0ce97a911dad2e6141daa0003730ea08 to your computer and use it in GitHub Desktop.
WGPU Compute Shader Particles in Nannou
// Storage block holding a runtime-sized array of 32-bit unsigned words.
struct Buffer {
    data: array<u32>,
}

// The two read-write buffers the compute pass ping-pongs between, plus the
// frame counter that selects which one is the destination.
@group(0) @binding(0) var<storage, read_write> ping: Buffer;
@group(0) @binding(1) var<storage, read_write> pong: Buffer;
@group(0) @binding(2) var<uniform> frame_num: u32;
// Writes `(src + 1) % 2` for one element, alternating the direction between
// the two buffers: while `(frame_num / 60)` is even the result goes from `ping`
// into `pong`, otherwise from `pong` into `ping`.
@compute @workgroup_size(1, 1, 1)
fn main(
    @builtin(global_invocation_id) id: vec3<u32>,
    @builtin(num_workgroups) num_groups: vec3<u32>,
) {
    // The host dispatches a 2D grid of workgroups, so fold (x, y) into one
    // row-major index. Indexing by `id.x` alone would make every row rewrite
    // the first row's elements and leave the rest untouched. With a workgroup
    // width of 1, the row width equals the number of workgroups along x.
    let index: u32 = id.y * num_groups.x + id.x;

    // Ignore invocations that fall past the end of either buffer.
    if (index >= arrayLength(&ping.data) || index >= arrayLength(&pong.data)) {
        return;
    }

    if ((frame_num / 60u) % 2u == 0u) {
        pong.data[index] = (ping.data[index] + 1u) % 2u;
    } else {
        ping.data[index] = (pong.data[index] + 1u) % 2u;
    }
}
use nannou::prelude::*;
/// Workgroup count along x for the compute dispatch in `update`.
const NUM_X: usize = 500;
/// Workgroup count along y for the compute dispatch in `update`.
const NUM_Y: usize = 2;
/// Total particle count: the grid size, and the number of vertices drawn in `view`.
const NUM_PARTICLES: usize = NUM_Y * NUM_X;
/// GPU resources used to draw the particles.
struct Render {
    /// Point-list pipeline built from `shaders/render.wgsl`.
    pipeline: wgpu::RenderPipeline,
    /// Particle positions read as vertices; overwritten each `update` by a
    /// copy from the compute position buffer.
    vertex_buffer: wgpu::Buffer,
}
/// GPU resources used by the particle compute pass.
struct Compute {
    /// Pipeline built from `shaders/cs.wgsl` with entry point `main`.
    pipeline: wgpu::ComputePipeline,
    /// Storage buffer holding two `f32` values per particle; it is the source
    /// of the per-frame copy into the render vertex buffer.
    position_buffer: wgpu::Buffer,
}
/// Application state built by `model` and shared by `update` and `view`.
struct Model {
    /// Render pipeline and the vertex buffer it draws from.
    render: Render,
    /// Compute pipeline and the position storage buffer.
    compute: Compute,
    /// Compute group 0: the position and velocity storage buffers.
    bind_group: wgpu::BindGroup,
    /// Byte length of the position data; the number of bytes copied from the
    /// position buffer into the vertex buffer each `update`.
    buffer_size: wgpu::BufferAddress,
    /// Window-resolution uniform; bound at group 1 in the compute pass and at
    /// group 0 in the render pass.
    uniform_bind_group: wgpu::BindGroup,
}
/// One particle vertex as laid out in the GPU vertex buffer.
///
/// `repr(C)` fixes the layout to two consecutive `f32` values (8 bytes),
/// matching a single `Float32x2` vertex attribute.
#[repr(C)]
#[derive(Clone, Copy)]
struct Vertex {
    /// Particle position (x, y).
    position: [f32; 2],
}
/// Entry point: `model` builds the state, `update` runs the compute step each
/// frame and `view` renders it.
fn main() {
    let app = nannou::app(model).update(update).view(view);
    app.run();
}
fn model(app: &App) -> Model {
let w_id = app.new_window().size(720, 1024).build().unwrap();
let window = app.window(w_id).unwrap();
let device = window.device();
// shader modules
let cs_mod = device.create_shader_module(wgpu::include_wgsl!("shaders/cs.wgsl"));
let render_mod = device.create_shader_module(wgpu::include_wgsl!("shaders/render.wgsl"));
// create position and velocity buffer
let usage =
wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::COPY_SRC;
let win_dims: Vec2 = window.rect().wh();
let pos_data = create_position_data(win_dims);
let vel_data = create_velocity_data();
let pos_buffer = device.create_buffer_init(&wgpu::BufferInitDescriptor {
label: Some("Position Storage buffer"),
contents: unsafe { wgpu::bytes::from_slice(&pos_data) },
usage,
});
let vel_buffer = device.create_buffer_init(&wgpu::BufferInitDescriptor {
label: Some("Velocity Storage Buffer"),
contents: unsafe { wgpu::bytes::from_slice(&vel_data) },
usage,
});
// create static uniform buffer
let u_res: [f32; 2] = [window.rect().w(), window.rect().h()];
let u_res_buffer = device.create_buffer_init(&wgpu::BufferInitDescriptor {
label: Some("Static uniforms buffer"),
contents: unsafe { wgpu::bytes::from_slice(&[u_res]) },
usage: wgpu::BufferUsages::UNIFORM,
});
// create bind group from data
let visibility = wgpu::ShaderStages::COMPUTE;
let bind_group_layout = wgpu::BindGroupLayoutBuilder::new()
.storage_buffer(visibility, false, false)
.storage_buffer(visibility, false, false)
.build(device);
let buffer_size = (std::mem::size_of::<f32>() * NUM_PARTICLES * 2) as wgpu::BufferAddress;
let bind_group = wgpu::BindGroupBuilder::new()
.buffer_bytes(&pos_buffer, 0, wgpu::BufferSize::new(buffer_size))
.buffer_bytes(&vel_buffer, 0, wgpu::BufferSize::new(buffer_size))
.build(device, &bind_group_layout);
// create uniform bind group
let uniform_bind_group_layout = wgpu::BindGroupLayoutBuilder::new()
.uniform_buffer(
wgpu::ShaderStages::COMPUTE | wgpu::ShaderStages::VERTEX,
false,
)
.build(device);
let uniform_bind_group = wgpu::BindGroupBuilder::new()
.buffer_bytes(
&u_res_buffer,
0,
wgpu::BufferSize::new(std::mem::size_of_val(&u_res) as u64),
)
.build(device, &uniform_bind_group_layout);
// create pipeline layout
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("Common Pipeline Layout"),
bind_group_layouts: &[&bind_group_layout, &uniform_bind_group_layout],
push_constant_ranges: &[],
});
let render_pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("Common Pipeline Layout"),
bind_group_layouts: &[&uniform_bind_group_layout],
push_constant_ranges: &[],
});
// create compute pipeline
let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("Compute Pipeline"),
layout: Some(&pipeline_layout),
module: &cs_mod,
entry_point: "main",
});
// create render pipeline
let vertex_buffer = device.create_buffer_init(&wgpu::BufferInitDescriptor {
label: Some("Vertex Data"),
contents: unsafe { wgpu::bytes::from_slice(&[pos_data]) },
usage: wgpu::BufferUsages::VERTEX | wgpu::BufferUsages::COPY_DST,
});
let render_pipeline =
wgpu::RenderPipelineBuilder::from_layout(&render_pipeline_layout, &render_mod)
.vertex_entry_point("vertexMain")
.fragment_shader(&render_mod)
.fragment_entry_point("fragmentMain")
.sample_count(window.msaa_samples())
.color_format(Frame::TEXTURE_FORMAT)
.primitive_topology(wgpu::PrimitiveTopology::PointList)
.add_vertex_buffer::<f32>(&wgpu::vertex_attr_array![0 => Float32x2])
.build(device);
Model {
render: Render {
pipeline: render_pipeline,
vertex_buffer,
},
bind_group,
compute: Compute {
pipeline: compute_pipeline,
position_buffer: pos_buffer,
},
buffer_size,
uniform_bind_group,
}
}
/// Advances the simulation by one step.
///
/// Dispatches the compute shader over a `NUM_X` x `NUM_Y` grid of workgroups.
/// It then copies the position storage buffer into the vertex buffer that
/// `view` draws from, and submits both commands to the window's queue.
fn update(app: &App, model: &mut Model, _update: Update) {
    let main_window = app.main_window();
    let gpu = main_window.device();
    let mut commands = gpu.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("Compute Encoder"),
    });

    let mut compute_pass = commands.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("Position Compute"),
    });
    compute_pass.set_pipeline(&model.compute.pipeline);
    compute_pass.set_bind_group(0, &model.bind_group, &[]);
    compute_pass.set_bind_group(1, &model.uniform_bind_group, &[]);
    compute_pass.dispatch_workgroups(NUM_X as u32, NUM_Y as u32, 1);
    // Ending the pass releases its mutable borrow of the encoder so the copy
    // below can be recorded.
    drop(compute_pass);

    commands.copy_buffer_to_buffer(
        &model.compute.position_buffer,
        0,
        &model.render.vertex_buffer,
        0,
        model.buffer_size,
    );
    main_window.queue().submit(std::iter::once(commands.finish()));
}
/// Renders one frame.
///
/// The particles are drawn with the custom point-list pipeline. An FPS
/// counter is then drawn near the top-left corner with nannou's `Draw` API.
fn view(app: &App, model: &Model, frame: Frame) {
    // `frame.command_encoder()` hands out a mutable borrow of the frame's
    // encoder. It must be released before `draw.to_frame` records into the
    // same frame, hence this scope.
    {
        let mut commands = frame.command_encoder();
        let target = frame.texture_view();
        let mut render_pass = wgpu::RenderPassBuilder::new()
            .color_attachment(target, |attachment| attachment)
            .begin(&mut commands);
        let particles = &model.render;
        render_pass.set_pipeline(&particles.pipeline);
        render_pass.set_bind_group(0, &model.uniform_bind_group, &[]);
        render_pass.set_vertex_buffer(0, particles.vertex_buffer.slice(..));
        render_pass.draw(0..NUM_PARTICLES as u32, 0..1);
    }

    let draw = app.draw();
    let bounds = app.main_window().rect();
    let fps_text = app.fps().floor().to_string();
    draw.text(fps_text.as_str())
        .x_y(bounds.left() + 40.0, bounds.top() - 40.0);
    draw.to_frame(app, &frame).unwrap();
}
/// Generates initial particle positions as interleaved `[x0, y0, x1, y1, ...]`
/// pairs.
///
/// Each `x` is `random_f32() * bounds.x` and each `y` is
/// `random_f32() * bounds.y`. These presumably lie in `[0, bounds)`, assuming
/// `random_f32` is uniform in `[0, 1)`.
fn create_position_data(bounds: Vec2) -> [f32; NUM_PARTICLES * 2] {
    let mut positions = [0.0_f32; NUM_PARTICLES * 2];
    for xy in positions.chunks_exact_mut(2) {
        xy[0] = random_f32() * bounds.x;
        xy[1] = random_f32() * bounds.y;
    }
    positions
}
/// Generates initial particle velocities as interleaved `[vx0, vy0, ...]`
/// pairs.
///
/// Every component is `random_f32() * 2.0 - 1.0`. This is presumably in
/// `[-1, 1)`, assuming `random_f32` is uniform in `[0, 1)`.
fn create_velocity_data() -> [f32; NUM_PARTICLES * 2] {
    let mut velocities = [0.0_f32; NUM_PARTICLES * 2];
    for component in velocities.iter_mut() {
        *component = random_f32() * 2.0 - 1.0;
    }
    velocities
}
// Resources read by the render stages (only `fragmentMain` uses them).
// `read` is the default access mode for storage buffers; it is spelled out here.
// NOTE(review): the Rust render pipeline layout binds a single uniform buffer
// at group 0, so confirm these declarations match the layout in use.
@group(0) @binding(0) var<storage, read> ping: array<u32>;
@group(0) @binding(1) var<storage, read> pong: array<u32>;
@group(0) @binding(2) var<uniform> frame_num: u32;
// Values passed from `vertexMain` to `fragmentMain`.
struct VertexOutput {
    // Clip-space position of the point.
    @builtin(position) pos: vec4f,
    // Index the fragment stage uses to look up `ping`/`pong`. WGSL requires
    // integer inter-stage values to be explicitly flat-interpolated.
    @location(0) @interpolate(flat) id: u32,
};
// Places each point at `pos / 20` in clip space, shifted right by a tenth of
// the instance index, and forwards the instance index to the fragment stage.
// NOTE(review): `view` issues a single instance (`0..1`), so `instance` is 0
// for every vertex; confirm whether `vertex_index` was intended.
@vertex
fn vertexMain(@location(0) pos: vec2f, @builtin(instance_index) instance: u32) -> VertexOutput {
    let shift = vec2f(f32(instance) / 10.0, 0.0);
    var vert_out: VertexOutput;
    vert_out.id = instance;
    vert_out.pos = vec4f(pos / 20.0 + shift, 0.0, 1.0);
    return vert_out;
}
// Colours the point red when the selected buffer holds 1 at `id`, and black
// when it holds 0.
@fragment
fn fragmentMain(@location(0) @interpolate(flat) id: u32) -> @location(0) vec4<f32> {
    // The compute shader writes `pong` while (frame_num / 60) is even and
    // `ping` otherwise. `select(f, t, cond)` returns `t` when `cond` holds, so
    // this reads the buffer written in the current phase.
    let reads_pong = (frame_num / 60u) % 2u == 0u;
    let on = select(ping[id], pong[id], reads_pong);
    return vec4<f32>(f32(on), 0.0, 0.0, 1.0);
}
@somecho
Copy link
Author

somecho commented Jun 24, 2024

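The only dependency is nannou 0.19.0. The Rust file loads its shaders with `wgpu::include_wgsl!`, so the shader paths are resolved relative to the Rust source file. Save the compute shader (the first WGSL listing, with the `@compute` entry point) as `shaders/cs.wgsl`. Save the render shader (the second listing, with `vertexMain` and `fragmentMain`) as `shaders/render.wgsl`.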

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment