Skip to content

Instantly share code, notes, and snippets.

@greggman
Last active June 10, 2024 05:09
Show Gist options
  • Save greggman/d84a4c96c75091f50f34266d714a99d5 to your computer and use it in GitHub Desktop.
Save greggman/d84a4c96c75091f50f34266d714a99d5 to your computer and use it in GitHub Desktop.
WebGPU Verlet Integration
@import url(https://webgpufundamentals.org/webgpu/resources/webgpu-lesson.css);

/* Let html and body fill the whole page with no default margin. */
html,
body {
  margin: 0;
  height: 100%;
}

/* Canvas background color: white by default, black in dark mode. */
:root {
  --bg-color: #fff;
}
@media (prefers-color-scheme: dark) {
  :root {
    --bg-color: #000;
  }
}

/* Block-level canvas that fills its container. */
canvas {
  display: block;
  width: 100%;
  height: 100%;
  background-color: var(--bg-color);
}

/* Stats overlay pinned to the top-left corner. */
#info {
  position: absolute;
  left: 0;
  top: 0;
  padding: 0.5em;
  margin: 0;
  background-color: rgba(0, 0, 0, 0.8);
  color: white;
  min-width: 8em;
}
<!-- WebGPU render target; stretched to the page by the CSS, its drawing-buffer size follows a ResizeObserver in index.js. -->
<canvas></canvas>
<!-- Overlay that index.js fills with averaged timing statistics. -->
<pre id="info"></pre>
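// WebGPU Verlet Integration: times a GPU compute pass against the same step in JavaScript.
// Scaffolding adapted from the "WebGPU Optimization - None" sample:
// https://webgpufundamentals.org/webgpu/webgpu-optimization-none.html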
import GUI from 'https://webgpufundamentals.org/3rdparty/muigui-0.x.module.js';
import {mat4, mat3, vec3} from 'https://webgpufundamentals.org/3rdparty/wgpu-matrix.module.js';
// see https://webgpufundamentals.org/webgpu/lessons/webgpu-timing.html
import TimingHelper from 'https://webgpufundamentals.org/webgpu/resources/js/timing-helper.js';
// see https://webgpufundamentals.org/webgpu/lessons/webgpu-timing.html
import RollingAverage from 'https://webgpufundamentals.org/webgpu/resources/js/rolling-average.js';
import {
makeShaderDataDefinitions,
makeStructuredView,
} from 'https://greggman.github.io/webgpu-utils/dist/0.x/webgpu-utils.module.js';
/**
 * CPU implementation of position-Verlet integration, using the same update
 * rule as the compute shader in main():
 *
 *   next = pos + (pos - prev) + (force / mass) * dt * dt
 *
 * Each particle occupies `stride` consecutive numbers in every array. Only the
 * first three (x, y, z) are read or written; any padding slots are left alone.
 * The default stride of 4 matches the 16-byte element stride of
 * `array<vec3<f32>>` used by the GPU buffers.
 */
class VerletIntegrator {
  /**
   * @param {ArrayLike<number>} positions - current positions, updated in place.
   * @param {ArrayLike<number>} prevPositions - previous positions, updated in place.
   * @param {ArrayLike<number>} forces - per-particle force, read only.
   * @param {number} mass - mass shared by every particle.
   * @param {number} [stride=4] - numbers per particle in each array (must be >= 3).
   */
  constructor(positions, prevPositions, forces, mass, stride = 4) {
    this.positions = positions;
    this.prevPositions = prevPositions;
    this.forces = forces;
    this.mass = mass;
    this.stride = stride;
  }

  /**
   * Advances every particle by one step and stores the pre-step position in
   * prevPositions.
   * @param {number} dt - time step.
   */
  integrateVerlet(dt) {
    const { positions, prevPositions, forces, mass, stride } = this;
    // Floor so a trailing partial element (length not a multiple of stride)
    // is skipped instead of being integrated with undefined -> NaN components.
    const count = Math.floor(positions.length / stride);
    for (let i = 0; i < count; i++) {
      const base = i * stride;
      const x = positions[base + 0];
      const y = positions[base + 1];
      const z = positions[base + 2];
      const dx = x - prevPositions[base + 0];
      const dy = y - prevPositions[base + 1];
      const dz = z - prevPositions[base + 2];
      const ax = forces[base + 0] / mass;
      const ay = forces[base + 1] / mass;
      const az = forces[base + 2] / mass;
      prevPositions[base + 0] = x;
      prevPositions[base + 1] = y;
      prevPositions[base + 2] = z;
      positions[base + 0] = x + dx + ax * dt * dt;
      positions[base + 1] = y + dy + ay * dt * dt;
      positions[base + 2] = z + dz + az * dt * dt;
    }
  }
}
// One rolling average per statistic shown in the #info overlay, created in
// the order fps, js, draw, compute pass, math.
const [fpsAverage, jsAverage, drawAverage, cpAverage, mathAverage] =
  Array.from({ length: 5 }, () => new RollingAverage());
// Converts a CSS color string to [r, g, b, a] bytes by painting it into a
// shared 1x1 2D canvas and reading that single pixel back.
const cssColorToRGBA8 = (() => {
  const ctx = new OffscreenCanvas(1, 1).getContext('2d', {willReadFrequently: true});
  const toRGBA8 = (cssColor) => {
    ctx.clearRect(0, 0, 1, 1);
    ctx.fillStyle = cssColor;
    ctx.fillRect(0, 0, 1, 1);
    const {data} = ctx.getImageData(0, 0, 1, 1);
    return [...data];
  };
  return toRGBA8;
})();
// Builds a CSS hsl() string from fractions: h is scaled to whole degrees,
// l to whole percent, s to percent without truncation.
const hsl = (h, s, l) => {
  const hueDeg = (h * 360) | 0;
  const satPct = s * 100;
  const lightPct = (l * 100) | 0;
  return `hsl(${hueDeg}, ${satPct}%, ${lightPct}%)`;
};
// Converts a CSS color string to [r, g, b, a] with each channel in 0..1.
const cssColorToRGBA = (cssColor) => {
  const bytes = cssColorToRGBA8(cssColor);
  return bytes.map((byte) => byte / 255);
};
// Converts fractional h, s, l (see hsl) to [r, g, b, a] with channels in 0..1.
const hslToRGBA = (h, s, l) => {
  const cssColor = hsl(h, s, l);
  return cssColorToRGBA(cssColor);
};
// Multiplies all elements of an array together, e.g. a workgroup size
// [8, 8, 4] -> 256 threads. An empty array yields 1 (the empty product)
// instead of throwing like reduce() without an initial value does.
const arrayProd = arr => arr.reduce((a, b) => a * b, 1);
// Returns a random number from Math.random() scaled into a range:
//   rand()          -> 0 to 1
//   rand(max)       -> 0 to max (a lone argument is the upper bound)
//   rand(min, max)  -> min to max
// When min is undefined, max is ignored and the range is 0 to 1.
function rand(min, max) {
  let lo = 0;
  let hi = 1;
  if (min !== undefined) {
    if (max === undefined) {
      hi = min;
    } else {
      lo = min;
      hi = max;
    }
  }
  return Math.random() * (hi - lo) + lo;
}
// Picks an element of `arr` at a uniformly random index
// (undefined for an empty array).
const randomArrayElement = (arr) => {
  const index = (Math.random() * arr.length) | 0;
  return arr[index];
};
/**
 * Sets up WebGPU, uploads a randomly initialized particle system and starts a
 * requestAnimationFrame loop. Each frame:
 *   1. runs one Verlet step for every particle in a compute pass,
 *   2. draws the first (at most) 10000 particles as small quads,
 *   3. if settings.timeJS is on, runs the same step on the JavaScript-side
 *      arrays via VerletIntegrator (those arrays are never re-uploaded; this
 *      only measures CPU cost),
 * and writes averaged timings into the #info overlay.
 */
async function main() {
  // navigator.gpu is undefined without WebGPU support (hence `?.`), and the
  // adapter can be null, so bail out before reading adapter.features.
  const adapter = await navigator.gpu?.requestAdapter();
  if (!adapter) {
    fail('could not init WebGPU');
    return;
  }
  const canTimestamp = adapter.features.has('timestamp-query');
  const device = await adapter.requestDevice({
    requiredFeatures: [
      ...(canTimestamp ? ['timestamp-query'] : []),
    ],
  });
  if (!device) {
    fail('could not init WebGPU');
    return;
  }

  const timingHelper = new TimingHelper(device);
  const cpTimingHelper = new TimingHelper(device);
  const infoElem = document.querySelector('#info');

  // Get a WebGPU context from the canvas and configure it
  const canvas = document.querySelector('canvas');
  const context = canvas.getContext('webgpu');
  const presentationFormat = navigator.gpu.getPreferredCanvasFormat();
  context.configure({
    device,
    format: presentationFormat,
    alphaMode: 'premultiplied',
  });

  // 25 * 25 * 25 workgroups * (8 * 8 * 4 = 256) invocations each = 4,000,000
  // invocations, exactly one per particle (numParticles below is 2000 * 2000).
  // Keep these in sync: the shader's bounds check only protects against
  // dispatching too many invocations, not too few.
  const workgroupSize = [8, 8, 4];
  const dispatchCount = [25, 25, 25];
  const numThreadsPerWorkgroup = arrayProd(workgroupSize);

  // One Verlet step per particle; same update rule as
  // VerletIntegrator.integrateVerlet.
  const cpCode = `
    @group(0) @binding(0) var<storage, read_write> positions: array<vec3<f32>>;
    @group(0) @binding(1) var<storage, read_write> prevPositions: array<vec3<f32>>;
    @group(0) @binding(2) var<storage, read> forces: array<vec3<f32>>;
    @group(0) @binding(3) var<uniform> mass: f32;
    @group(0) @binding(4) var<uniform> dt: f32;

    @compute @workgroup_size(${workgroupSize})
    fn main(
      @builtin(workgroup_id) workgroup_id : vec3<u32>,
      @builtin(local_invocation_index) local_invocation_index: u32,
      @builtin(num_workgroups) num_workgroups: vec3<u32>
    ){
      // workgroup_index is similar to local_invocation_index except for
      // workgroups, not threads inside a workgroup.
      // It is not a builtin so we compute it ourselves.
      let workgroup_index =
         workgroup_id.x +
         workgroup_id.y * num_workgroups.x +
         workgroup_id.z * num_workgroups.x * num_workgroups.y;

      // global_invocation_index is like local_invocation_index
      // except linear across all invocations across all dispatched
      // workgroups. It is not a builtin so we compute it ourselves.
      let i =
         workgroup_index * ${numThreadsPerWorkgroup} +
         local_invocation_index;

      if (i < arrayLength(&positions)) {
        let dPos = positions[i] - prevPositions[i];
        let a = forces[i] / mass;
        prevPositions[i] = positions[i];
        positions[i] = positions[i] + dPos + a * dt * dt;
      }
    }
  `;

  const cpModule = device.createShaderModule({
    code: cpCode,
  });

  const cpPipeline = device.createComputePipeline({
    layout: 'auto',
    compute: {
      module: cpModule,
    },
  });

  const numParticles = 2000 * 2000;
  const cpDefs = makeShaderDataDefinitions(cpCode);
  // vec3<f32> array elements have a 16-byte stride, hence 4 floats * 4 bytes.
  const positionsValues = makeStructuredView(cpDefs.storages.positions, new ArrayBuffer(numParticles * 4 * 4));
  const prevPositionsValues = makeStructuredView(cpDefs.storages.prevPositions, new ArrayBuffer(numParticles * 4 * 4));
  const forcesValues = makeStructuredView(cpDefs.storages.forces, new ArrayBuffer(numParticles * 4 * 4));
  const massValues = makeStructuredView(cpDefs.uniforms.mass);
  const dtValues = makeStructuredView(cpDefs.uniforms.dt);

  const mass = 0.1;
  const integrator = new VerletIntegrator(
      positionsValues.views,
      prevPositionsValues.views,
      forcesValues.views,
      mass);

  const positionsBuffer = device.createBuffer({
    size: positionsValues.arrayBuffer.byteLength,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
  });
  const prevPositionsBuffer = device.createBuffer({
    size: prevPositionsValues.arrayBuffer.byteLength,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
  });
  const forcesBuffer = device.createBuffer({
    size: forcesValues.arrayBuffer.byteLength,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
  });
  const massBuffer = device.createBuffer({
    size: massValues.arrayBuffer.byteLength,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
  });
  const dtBuffer = device.createBuffer({
    size: dtValues.arrayBuffer.byteLength,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
  });

  // Random 2D start positions in [-1, 1), a small random initial velocity
  // (prev = pos + jitter) and a small random constant force. z stays 0.
  for (let i = 0; i < numParticles; ++i) {
    const x = rand(-1, 1);
    const y = rand(-1, 1);
    positionsValues.views[i * 4 + 0] = x;
    positionsValues.views[i * 4 + 1] = y;
    prevPositionsValues.views[i * 4 + 0] = x + rand(-0.001, 0.001);
    prevPositionsValues.views[i * 4 + 1] = y + rand(-0.001, 0.001);
    forcesValues.views[i * 4 + 0] = rand(-0.001, 0.001);
    forcesValues.views[i * 4 + 1] = rand(-0.001, 0.001);
  }
  massValues.set(mass);

  device.queue.writeBuffer(positionsBuffer, 0, positionsValues.arrayBuffer);
  device.queue.writeBuffer(prevPositionsBuffer, 0, prevPositionsValues.arrayBuffer);
  device.queue.writeBuffer(forcesBuffer, 0, forcesValues.arrayBuffer);
  device.queue.writeBuffer(massBuffer, 0, massValues.arrayBuffer);

  const cpBindGroup = device.createBindGroup({
    layout: cpPipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: positionsBuffer } },
      { binding: 1, resource: { buffer: prevPositionsBuffer } },
      { binding: 2, resource: { buffer: forcesBuffer } },
      { binding: 3, resource: { buffer: massBuffer } },
      { binding: 4, resource: { buffer: dtBuffer } },
    ],
  });

  // Draws one screen-aligned quad (two triangles) per instance, centered on
  // that particle's xy position; `size` is in pixels.
  const code = `
    struct Uniforms {
      resolution: vec2f,
      size: f32
    };

    struct VSOutput {
      @builtin(position) position: vec4f,
    };

    @group(0) @binding(0) var<storage, read> positions: array<vec3<f32>>;
    @group(0) @binding(1) var<uniform> uni: Uniforms;

    @vertex fn vs(
      @builtin(vertex_index) vertNdx: u32,
      @builtin(instance_index) instNdx: u32,
    ) -> VSOutput {
      let points = array(
        vec2f(-1, -1),
        vec2f( 1, -1),
        vec2f(-1,  1),
        vec2f(-1,  1),
        vec2f( 1, -1),
        vec2f( 1,  1),
      );
      var vsOut: VSOutput;
      let pos = points[vertNdx];
      vsOut.position = vec4f(positions[instNdx].xy + pos * uni.size / uni.resolution, 0, 1);
      return vsOut;
    }

    @fragment fn fs(vsOut: VSOutput) -> @location(0) vec4f {
      return vec4f(1, 1, 0, 1); // yellow
    }
  `;

  const module = device.createShaderModule({ code });
  const defs = makeShaderDataDefinitions(code);
  const uniformsValues = makeStructuredView(defs.uniforms.uni);
  const uniformsBuffer = device.createBuffer({
    size: uniformsValues.arrayBuffer.byteLength,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
  });

  const pipeline = device.createRenderPipeline({
    label: 'draw particles',
    layout: 'auto',
    vertex: {
      module,
    },
    fragment: {
      module,
      targets: [{ format: presentationFormat }],
    },
  });

  const bindGroup = device.createBindGroup({
    layout: pipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: positionsBuffer } },
      { binding: 1, resource: { buffer: uniformsBuffer } },
    ],
  });

  const renderPassDescriptor = {
    label: 'our basic canvas renderPass',
    colorAttachments: [
      {
        // view: <- to be filled out when we render
        clearValue: [0.3, 0.3, 0.3, 1],
        loadOp: 'clear',
        storeOp: 'store',
      },
    ],
  };

  const settings = {
    timeJS: false,
  };

  const gui = new GUI();
  gui.add(settings, 'timeJS');

  const canvasToSizeMap = new WeakMap();

  let then = 0;
  let frameCount = 0;
  function render(time) {
    ++frameCount;
    time *= 0.001;  // convert to seconds
    // NOTE: on the first frame `then` is still 0, so deltaTime is the time
    // since the page's time origin rather than a single frame.
    const deltaTime = time - then;
    then = time;

    const startTimeMs = performance.now();

    let width = 1;
    let height = 1;
    const entry = canvasToSizeMap.get(canvas);
    if (entry) {
      width = Math.max(1, Math.min(entry.contentBoxSize[0].inlineSize, device.limits.maxTextureDimension2D));
      height = Math.max(1, Math.min(entry.contentBoxSize[0].blockSize, device.limits.maxTextureDimension2D));
    }
    if (canvas.width !== width || canvas.height !== height) {
      canvas.width = width;
      canvas.height = height;
    }

    // Get the current texture from the canvas context and
    // set it as the texture to render to.
    const canvasTexture = context.getCurrentTexture();
    renderPassDescriptor.colorAttachments[0].view = canvasTexture.createView();

    const encoder = device.createCommandEncoder();

    // Compute pass: one Verlet step on the GPU copy of the particles.
    {
      dtValues.set(deltaTime);
      device.queue.writeBuffer(dtBuffer, 0, dtValues.arrayBuffer);

      const pass = cpTimingHelper.beginComputePass(encoder);
      pass.setPipeline(cpPipeline);
      pass.setBindGroup(0, cpBindGroup);
      pass.dispatchWorkgroups(...dispatchCount);
      pass.end();
    }

    // Render pass: draw the (updated) particles.
    {
      uniformsValues.set({
        resolution: [width, height],
        size: 4,
      });
      device.queue.writeBuffer(uniformsBuffer, 0, uniformsValues.arrayBuffer);

      const pass = timingHelper.beginRenderPass(encoder, renderPassDescriptor);
      pass.setPipeline(pipeline);
      pass.setBindGroup(0, bindGroup);
      // 4 million particles is too many
      pass.draw(6, Math.min(10000, numParticles));
      pass.end();
    }

    // Optional CPU integration of the JavaScript-side arrays, timed separately.
    const mathStartMs = performance.now();
    if (settings.timeJS) {
      integrator.integrateVerlet(deltaTime);
    }
    const mathElapsedTimeMs = performance.now() - mathStartMs;

    const commandBuffer = encoder.finish();
    device.queue.submit([commandBuffer]);

    // Samples are divided by 1000 here and again when displayed below; this
    // presumes TimingHelper reports nanoseconds (so the display is in ms) —
    // TODO confirm against timing-helper.js.
    timingHelper.getResult().then(gpuTime => {
      drawAverage.addSample(gpuTime / 1000);
    });
    cpTimingHelper.getResult().then(cpTime => {
      cpAverage.addSample(cpTime / 1000);
    });

    const elapsedTimeMs = performance.now() - startTimeMs;
    fpsAverage.addSample(1 / deltaTime);
    jsAverage.addSample(elapsedTimeMs);
    mathAverage.addSample(mathElapsedTimeMs);

    infoElem.textContent = `\
js : ${jsAverage.get().toFixed(1)}ms
math : ${mathAverage.get().toFixed(1)}ms
fps : ${fpsAverage.get().toFixed(0)}
draw : ${canTimestamp ? `${(drawAverage.get() / 1000).toFixed(1)}ms` : 'N/A'}
compute: ${canTimestamp ? `${(cpAverage.get() / 1000).toFixed(1)}ms` : 'N/A'}
`;

    requestAnimationFrame(render);
  }
  requestAnimationFrame(render);

  // Record the canvas's CSS content-box size; render() reads it each frame
  // to size the drawing buffer.
  const observer = new ResizeObserver(entries => {
    entries.forEach(e => canvasToSizeMap.set(e.target, e));
  });
  observer.observe(canvas);
}
/**
 * Reports a fatal setup problem to the user with alert().
 * @param {string} msg - text to show.
 */
function fail(msg) {
  const message = msg;
  alert(message);
}
// Start the async setup. A rejection (e.g. from requestDevice) is logged and
// shown to the user instead of becoming an unhandled promise rejection.
main().catch((err) => {
  console.error(err);
  fail(`WebGPU sample failed: ${err?.message ?? err}`);
});
{"name":"WebGPU Verlet Integration","settings":{},"filenames":["index.html","index.css","index.js"]}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment