Skip to content

Instantly share code, notes, and snippets.

@jhurliman
Created October 4, 2022 01:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jhurliman/997cae2d942ed6b841d63326ed971b4f to your computer and use it in GitHub Desktop.
Save jhurliman/997cae2d942ed6b841d63326ed971b4f to your computer and use it in GitHub Desktop.
WebGL2 compute shaders using transform feedback
/**
 * Matches GLSL shader output declarations (`out`, or legacy `varying`) and captures the variable name
 * in group 1.
 *
 * An optional precision qualifier is accepted (`out highp vec4 color;`), as are leading interpolation
 * qualifiers (`flat out int id;`). Only single, non-array declarations are recognized. Declarations
 * inside comments or inactive preprocessor branches are matched too, because this is a plain text scan.
 *
 * Must keep the `g` flag: it is consumed with `String.prototype.matchAll`.
 */
const VARYING_REGEX = /\b(?:varying|out)\s+(?:(?:lowp|mediump|highp)\s+)?\w+\s+(\w+)\s*;/g
/**
 * Prefixes every line of `source` with `<n>:`, counting up from `offset`.
 *
 * Any JS line terminator (\n, \r, \u2028, \u2029) starts a new line. An empty string, or an empty
 * trailing line, still receives a prefix.
 */
const lineNumbers = (source: string, offset = 0): string =>
  source
    // The capture group keeps separators at odd indices, so line text sits at even indices.
    .split(/(\r|\n|\u2028|\u2029)/)
    .map((piece, index) => (index % 2 === 0 ? `${offset + index / 2}:${piece}` : piece))
    .join('')
/**
 * WebGL component-type enums keyed by typed-array constructor.
 */
const GL_TYPE_BY_ARRAY_CONSTRUCTOR = new Map<unknown, number>([
  [Float32Array, 5126], // FLOAT
  [Int8Array, 5120], // BYTE
  [Int16Array, 5122], // SHORT
  [Int32Array, 5124], // INT
  [Uint8Array, 5121], // UNSIGNED_BYTE
  [Uint8ClampedArray, 5121], // UNSIGNED_BYTE
  [Uint16Array, 5123], // UNSIGNED_SHORT
  [Uint32Array, 5125], // UNSIGNED_INT
])

/**
 * Resolves the WebGL component type for a typed array.
 * Returns null for views with no WebGL equivalent (e.g. DataView, Float64Array).
 */
const getDataType = (data: ArrayBufferView): number | null =>
  GL_TYPE_BY_ARRAY_CONSTRUCTOR.get(data.constructor) ?? null
/**
 * Represents compute input data.
 *
 * Each input is uploaded to its own GPU buffer. It feeds the vertex shader attribute whose name
 * equals the key under which the input appears in `WebGLComputeOptions.inputs`.
 */
export interface WebGLComputeInput {
/**
 * Input data view. Only the typed arrays recognized by `getDataType` are supported:
 * Float32Array, Int8Array, Int16Array, Int32Array, Uint8Array, Uint8ClampedArray, Uint16Array,
 * Uint32Array. Other views (e.g. DataView, Float64Array) have no WebGL component type.
 */
data: ArrayBufferView
/**
 * The size (per vertex) of the data array. Used to allocate data to each vertex.
 * This is the component count passed to the vertex attribute pointer. The number of shader
 * invocations is derived from `data.length / size`, taking the maximum over all inputs.
 */
size: 1 | 2 | 3 | 4
}
/**
 * WebGLCompute constructor parameters. Accepts a list of program inputs and compute shader source.
 */
export interface WebGLComputeOptions {
/**
 * Named inputs. Each key is looked up as a vertex attribute (`in`) name in `compute`.
 */
inputs: Record<string, WebGLComputeInput>
/**
 * GLSL source compiled as the vertex stage of the compute program. Every `out` (or `varying`)
 * declaration found in it is captured via transform feedback and returned by `compute()`.
 * NOTE(review): the built-in no-op fragment shader is `#version 300 es`, so this source
 * presumably must declare `#version 300 es` as well for the program to link — confirm.
 */
compute: string
}
/**
 * Represents a compute result.
 * Maps each captured shader output name to the data read back from its GPU buffer.
 * NOTE: `WebGLCompute.compute()` returns its internal container arrays, so a later call
 * overwrites a previous result's contents.
 */
export type WebGLComputeResult = Record<string, Float32Array>
/**
 * Constructs a WebGL compute program via transform feedback. Can be used to compute and serialize data from the GPU.
 *
 * The `compute` source runs as a vertex shader, once per input element. It is drawn as `POINTS` with
 * rasterization discarded. Every `out`/`varying` it declares is captured into its own GPU buffer and read
 * back by {@link WebGLCompute.compute}.
 */
export class WebGLCompute {
  readonly gl: WebGL2RenderingContext
  readonly program: WebGLProgram
  readonly VAO: WebGLVertexArrayObject
  readonly transformFeedback: WebGLTransformFeedback
  /** GPU buffers keyed by input or output name. */
  readonly buffers = new Map<string, WebGLBuffer>()
  /** CPU-side arrays that receive each output on readback; reused by every `compute()` call. */
  readonly containers = new Map<string, ArrayBufferView>()
  /** Number of vertices (shader invocations) drawn per `compute()` call. */
  private _length = 0

  /**
   * @param options - Named inputs and the compute (vertex) shader source.
   * @param gl - WebGL2 context to run on; defaults to a context from a fresh, detached canvas.
   * @throws Error if WebGL2 is unavailable, an input uses an unsupported array type, or the shader
   *   fails to compile or link.
   */
  constructor(options: WebGLComputeOptions, gl = document.createElement('canvas').getContext('webgl2')!) {
    // The default's `!` only silences the checker: getContext returns null when WebGL2 is unsupported.
    if (!gl) throw new Error('WebGL2 is not supported')
    this.gl = gl

    // Validate inputs before allocating any GPU resources.
    for (const [name, { data }] of Object.entries(options.inputs)) {
      if (getDataType(data) === null) {
        throw new Error(`Unsupported data type for compute input "${name}": ${data.constructor.name}`)
      }
    }

    // Output names must be registered before linking. Dedupe, because naming the same varying twice
    // in transformFeedbackVaryings makes the link fail.
    const outputs = [...new Set(Array.from(options.compute.matchAll(VARYING_REGEX), ([, varying]) => varying))]

    this.program = this.initProgram(options.compute, outputs)
    this.VAO = this.initInputs(options.inputs)
    this.transformFeedback = this.initOutputs(outputs)
  }

  /** Compiles one shader stage; on failure deletes it and throws with a line-numbered source listing. */
  private compileShader(type: number, source: string): WebGLShader {
    const { gl } = this
    const shader = gl.createShader(type)!
    gl.shaderSource(shader, source)
    gl.compileShader(shader)
    // Check COMPILE_STATUS, not the info log: a successful compile may still log warnings.
    if (!gl.getShaderParameter(shader, gl.COMPILE_STATUS)) {
      const log = gl.getShaderInfoLog(shader)
      gl.deleteShader(shader)
      // GLSL line numbers are 1-based, so number the listing from 1 to line up with the log.
      throw new Error(`${log}\n${lineNumbers(source, 1)}`)
    }
    return shader
  }

  /** Links `source` with a no-op fragment stage, registering `outputs` for transform feedback capture. */
  private initProgram(source: string, outputs: string[]): WebGLProgram {
    const { gl } = this
    const vertexShader = this.compileShader(gl.VERTEX_SHADER, source)
    const fragmentShader = this.compileShader(gl.FRAGMENT_SHADER, '#version 300 es\nprecision highp float;\nvoid main(){}')
    const program = gl.createProgram()!
    gl.attachShader(program, vertexShader)
    gl.attachShader(program, fragmentShader)
    // SEPARATE_ATTRIBS: output i is written to the buffer bound at transform feedback index i.
    gl.transformFeedbackVaryings(program, outputs, gl.SEPARATE_ATTRIBS)
    gl.linkProgram(program)
    // A linked program stays usable after its shaders are detached and deleted.
    gl.detachShader(program, vertexShader)
    gl.detachShader(program, fragmentShader)
    gl.deleteShader(vertexShader)
    gl.deleteShader(fragmentShader)
    if (!gl.getProgramParameter(program, gl.LINK_STATUS)) {
      const log = gl.getProgramInfoLog(program)
      gl.deleteProgram(program)
      // `||` on purpose: an empty log is as unhelpful as a missing one.
      throw new Error(log || 'Failed to link compute program')
    }
    return program
  }

  /** Maps each active vertex attribute name to its GLSL type enum (e.g. gl.FLOAT, gl.INT_VEC2). */
  private activeAttribTypes(): Map<string, number> {
    const { gl } = this
    const types = new Map<string, number>()
    const count: number = gl.getProgramParameter(this.program, gl.ACTIVE_ATTRIBUTES)
    for (let i = 0; i < count; i++) {
      const info = gl.getActiveAttrib(this.program, i)
      if (info) types.set(info.name, info.type)
    }
    return types
  }

  /** Whether a GLSL type enum is a signed or unsigned integer scalar/vector. */
  private isIntegerType(type: number | undefined): boolean {
    const { gl } = this
    switch (type) {
      case gl.INT:
      case gl.INT_VEC2:
      case gl.INT_VEC3:
      case gl.INT_VEC4:
      case gl.UNSIGNED_INT:
      case gl.UNSIGNED_INT_VEC2:
      case gl.UNSIGNED_INT_VEC3:
      case gl.UNSIGNED_INT_VEC4:
        return true
      default:
        return false
    }
  }

  /** Number of 32-bit components one value of a GLSL type enum occupies (defaults to 1). */
  private componentCount(type: number): number {
    const { gl } = this
    switch (type) {
      case gl.FLOAT_VEC2:
      case gl.INT_VEC2:
      case gl.UNSIGNED_INT_VEC2:
        return 2
      case gl.FLOAT_VEC3:
      case gl.INT_VEC3:
      case gl.UNSIGNED_INT_VEC3:
        return 3
      case gl.FLOAT_VEC4:
      case gl.INT_VEC4:
      case gl.UNSIGNED_INT_VEC4:
      case gl.FLOAT_MAT2:
        return 4
      case gl.FLOAT_MAT2x3:
      case gl.FLOAT_MAT3x2:
        return 6
      case gl.FLOAT_MAT2x4:
      case gl.FLOAT_MAT4x2:
        return 8
      case gl.FLOAT_MAT3:
        return 9
      case gl.FLOAT_MAT3x4:
      case gl.FLOAT_MAT4x3:
        return 12
      case gl.FLOAT_MAT4:
        return 16
      default:
        return 1
    }
  }

  /** Uploads each input to a buffer and records its attribute layout in a new VAO; sets `_length`. */
  private initInputs(inputs: Record<string, WebGLComputeInput>): WebGLVertexArrayObject {
    const { gl } = this
    const attribTypes = this.activeAttribTypes()
    const vao = gl.createVertexArray()!
    gl.bindVertexArray(vao)
    try {
      for (const [name, { data, size }] of Object.entries(inputs)) {
        // Non-null: validated in the constructor.
        const dataType = getDataType(data)!
        const buffer = gl.createBuffer()!
        gl.bindBuffer(gl.ARRAY_BUFFER, buffer)
        // Specified once by the app and sourced by GL on every draw.
        gl.bufferData(gl.ARRAY_BUFFER, data, gl.STATIC_DRAW)
        this.buffers.set(name, buffer)

        const location = gl.getAttribLocation(this.program, name)
        // -1 means the shader has no active attribute of this name; enabling index -1 is an error.
        if (location !== -1) {
          gl.enableVertexAttribArray(location)
          // The pointer kind must match the shader's declared base type (int/uint vs float), not the
          // JS array type; WebGL2 rejects draws where they disagree.
          if (this.isIntegerType(attribTypes.get(name))) {
            if (dataType === gl.FLOAT) {
              throw new Error(`Compute input "${name}" is an integer attribute but was given float data`)
            }
            gl.vertexAttribIPointer(location, size, dataType, 0, 0)
          } else {
            gl.vertexAttribPointer(location, size, dataType, false, 0, 0)
          }
        }

        // Only whole elements count toward the invocation count.
        const elements = Math.floor((data as unknown as ArrayLike<number>).length / size)
        this._length = Math.max(this._length, elements)
      }
    } finally {
      gl.bindVertexArray(null)
      gl.bindBuffer(gl.ARRAY_BUFFER, null)
    }
    return vao
  }

  /**
   * Creates, for each output, a readback array and a GPU buffer sized for `_length` invocations of that
   * output's full type. Each buffer is bound at the output's index in a new transform feedback object.
   */
  private initOutputs(outputs: string[]): WebGLTransformFeedback {
    const { gl } = this
    const transformFeedback = gl.createTransformFeedback()!
    gl.bindTransformFeedback(gl.TRANSFORM_FEEDBACK, transformFeedback)
    outputs.forEach((name, index) => {
      // vec/mat outputs write several components per invocation. A buffer sized for one float each would
      // overflow, and the draw would fail with INVALID_OPERATION.
      const info = gl.getTransformFeedbackVarying(this.program, index)
      const components = info ? this.componentCount(info.type) * info.size : 1
      const data = new Float32Array(this._length * components)
      this.containers.set(name, data)
      const buffer = gl.createBuffer()!
      gl.bindBuffer(gl.ARRAY_BUFFER, buffer)
      // Written by GL (transform feedback), then read back by the app.
      gl.bufferData(gl.ARRAY_BUFFER, data, gl.DYNAMIC_READ)
      gl.bindBuffer(gl.ARRAY_BUFFER, null)
      gl.bindBufferBase(gl.TRANSFORM_FEEDBACK_BUFFER, index, buffer)
      this.buffers.set(name, buffer)
    })
    gl.bindTransformFeedback(gl.TRANSFORM_FEEDBACK, null)
    return transformFeedback
  }

  /**
   * Runs and reads from the compute program.
   *
   * Readback is synchronous, so this waits for the GPU to finish. The returned arrays are this instance's
   * `containers` and are overwritten by the next call; copy them to keep results. Integer outputs arrive
   * bit-for-bit in the Float32Array; view them via `new Int32Array(arr.buffer)` (or `Uint32Array`).
   */
  compute(): WebGLComputeResult {
    const { gl } = this
    // beginTransformFeedback fails for a program with no captured outputs; with none there is nothing to run.
    if (this.containers.size > 0) {
      gl.useProgram(this.program)
      gl.bindVertexArray(this.VAO)
      gl.bindTransformFeedback(gl.TRANSFORM_FEEDBACK, this.transformFeedback)
      // Only the vertex stage and feedback capture are needed; skip rasterization.
      gl.enable(gl.RASTERIZER_DISCARD)
      gl.beginTransformFeedback(gl.POINTS)
      gl.drawArrays(gl.POINTS, 0, this._length)
      gl.endTransformFeedback()
      gl.disable(gl.RASTERIZER_DISCARD)
      gl.bindTransformFeedback(gl.TRANSFORM_FEEDBACK, null)
      gl.bindVertexArray(null)
      gl.useProgram(null)
    }

    // Read output buffer data
    const result: WebGLComputeResult = {}
    for (const [name, data] of this.containers) {
      gl.bindBuffer(gl.ARRAY_BUFFER, this.buffers.get(name)!)
      gl.getBufferSubData(gl.ARRAY_BUFFER, 0, data)
      // initOutputs only ever stores Float32Arrays in `containers`.
      result[name] = data as Float32Array
    }
    gl.bindBuffer(gl.ARRAY_BUFFER, null)
    return result
  }

  /**
   * Disposes the compute pipeline from GPU memory: program, VAO, transform feedback object and all buffers.
   * The instance must not be used afterwards.
   */
  dispose(): void {
    const { gl } = this
    gl.deleteProgram(this.program)
    gl.deleteVertexArray(this.VAO)
    gl.deleteTransformFeedback(this.transformFeedback)
    this.buffers.forEach((buffer) => gl.deleteBuffer(buffer))
  }
}
/*
 * Example: add each element's vertex index to its value.
 * Expected log: { result: Float32Array(5) [0, 2, 4, 6, 8] }
 */
const compute = new WebGLCompute({
  compute: /* glsl */ `#version 300 es
in float source;
out float result;
void main() {
result = source + float(gl_VertexID);
}
`,
  // One float per invocation, read by the shader's `in float source`.
  inputs: { source: { size: 1, data: new Float32Array([0, 1, 2, 3, 4]) } },
})
console.log(compute.compute())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment