Skip to content

Instantly share code, notes, and snippets.

@SabinT
Last active February 7, 2023 22:34
Show Gist options
  • Save SabinT/deed2a263c3fe4f99fdaa39bdb19bcc2 to your computer and use it in GitHub Desktop.
Save SabinT/deed2a263c3fe4f99fdaa39bdb19bcc2 to your computer and use it in GitHub Desktop.
Unity3D: Generic template for passing options to a compute shader and dispatching it; avoids the common boilerplate of Shader.PropertyToID, shader.SetInt, etc.
using EasyButtons;
namespace Lumic.Compute
{
using System;
using System.Linq;
using System.Reflection;
using System.Collections.Generic;
using UnityEngine;
/// <summary>
/// Example options object for <see cref="ExampleComputeShaderDriver"/>. Only fields
/// decorated with <see cref="PassToShader"/> are forwarded to the compute shader.
/// </summary>
[Serializable]
public class ExampleComputeShaderParams
{
    /// <summary>Forwarded under the explicit shader property name "_Number".</summary>
    [PassToShader(Name = "_Number")]
    public float Num;

    /// <summary>No name override, so it is forwarded under the field name "Vec".</summary>
    [PassToShader]
    public Vector4 Vec;
}
/// <summary>
/// Sample driver: each frame it applies the <see cref="ExampleComputeShaderParams"/> options
/// and dispatches the "TestKernel" kernel of the assigned compute shader.
/// </summary>
public class ExampleComputeShaderDriver : GenericComputeShaderDriver<ExampleComputeShaderParams>
{
    // Name of the kernel (#pragma kernel) dispatched every frame.
    private const string KernelName = "TestKernel";

    /// <summary>Unity per-frame callback.</summary>
    public void Update() => this.ApplyPropertiesAndDispatch(KernelName);
}
/// <summary>
/// Base component that forwards the <see cref="PassToShader"/>-decorated fields of
/// <see cref="Options"/> to <see cref="ComputeShader"/> and dispatches kernels.
/// Subclasses therefore don't hand-write Shader.PropertyToID / SetInt / SetTexture calls.
/// </summary>
/// <typeparam name="T">Options type whose decorated fields are sent to the shader.</typeparam>
public class GenericComputeShaderDriver<T> : MonoBehaviour where T : new()
{
    public ComputeShader ComputeShader;

    // Despite the name, this is the number of thread groups per axis passed to Dispatch.
    [Tooltip(
        "Set this to match the resolution of the buffers divided by 'numthreads' in the compute shader kernels.")]
    public Vector3Int ThreadGroupSize = Vector3Int.one;

    public T Options = new T();

    /// <summary>
    /// A dictionary of decorated fields from <see cref="Options"/> that need to be
    /// passed to the compute shader. The value is the "ShaderId".
    /// Built lazily on first use and cached afterwards.
    /// </summary>
    private Dictionary<FieldInfo, int> fieldsWithShaderIds = null;

    /// <summary>
    /// Fields whose type is not supported and have already been reported.
    /// The error is therefore logged once per field rather than on every call;
    /// <see cref="ApplyProperties"/> may run every frame.
    /// </summary>
    private readonly HashSet<FieldInfo> reportedUnsupportedFields = new HashSet<FieldInfo>();

    /// <summary>
    /// Unity start callback. Logs a message if no compute shader is assigned;
    /// otherwise builds the field → shader property ID map up front.
    /// </summary>
    protected virtual void Start()
    {
        if (this.ComputeShader == null)
        {
            Debug.Log("No compute shader assigned!");
            return;
        }

        this.BuildShaderAttributesMapIfNeeded();
    }

    /// <summary>
    /// Populates <see cref="fieldsWithShaderIds"/> once. It maps every instance field of
    /// <see cref="Options"/> marked with <see cref="PassToShader"/> to the property ID of
    /// its shader name: the attribute's Name if non-blank, otherwise the field name.
    /// </summary>
    private void BuildShaderAttributesMapIfNeeded()
    {
        if (this.fieldsWithShaderIds != null)
        {
            return;
        }

        // Reflect over the runtime type of Options (public and non-public instance fields).
        IEnumerable<FieldInfo> decoratedFields = this.Options.GetType()
            .GetFields(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance)
            .Where(field => Attribute.IsDefined(field, typeof(PassToShader)));

        this.fieldsWithShaderIds = new Dictionary<FieldInfo, int>();
        foreach (FieldInfo field in decoratedFields)
        {
            PassToShader attr = field.GetCustomAttribute<PassToShader>();
            string shaderVariableName = !string.IsNullOrWhiteSpace(attr.Name)
                ? attr.Name
                : field.Name;
            this.fieldsWithShaderIds.Add(field, Shader.PropertyToID(shaderVariableName));
        }
    }

    /// <summary>
    /// Copies every decorated field of <see cref="Options"/> into <see cref="ComputeShader"/>.
    /// Textures and buffers are bound to <paramref name="kernel"/>; other values are set
    /// through the shader-wide setters. Null values are skipped, including unassigned or
    /// destroyed Unity objects. Does nothing if no compute shader is assigned.
    /// </summary>
    /// <param name="kernel">Name of the kernel that textures/buffers are bound to.</param>
    [Button]
    public void ApplyProperties(string kernel)
    {
        this.BuildShaderAttributesMapIfNeeded();

        if (this.ComputeShader == null)
        {
            return;
        }

        int kernelIndex = this.ComputeShader.FindKernel(kernel);
        foreach (KeyValuePair<FieldInfo, int> pair in this.fieldsWithShaderIds)
        {
            object value = pair.Key.GetValue(this.Options);

            // A boxed `value == null` only catches true C# nulls. Unassigned ("fake null")
            // or destroyed Unity objects are non-null references that Unity's overloaded ==
            // reports as null. Binding one would make the SetTexture call fail.
            if (value == null || (value is UnityEngine.Object unityObject && unityObject == null))
            {
                continue;
            }

            this.SetShaderValue(kernelIndex, pair.Key, pair.Value, value);
        }
    }

    /// <summary>
    /// Sends one option value to <see cref="ComputeShader"/> using the setter that matches
    /// its runtime type. Unsupported types are reported once per field via Debug.LogError.
    /// </summary>
    /// <param name="kernelIndex">Kernel that textures and buffers are bound to.</param>
    /// <param name="field">Source field; used only for error reporting.</param>
    /// <param name="nameId">Shader property ID from Shader.PropertyToID.</param>
    /// <param name="value">Non-null field value.</param>
    private void SetShaderValue(int kernelIndex, FieldInfo field, int nameId, object value)
    {
        switch (value)
        {
            case bool b:
                this.ComputeShader.SetBool(nameId, b);
                break;
            case int i:
                this.ComputeShader.SetInt(nameId, i);
                break;
            case float f:
                this.ComputeShader.SetFloat(nameId, f);
                break;
            case Vector2 v:
                this.ComputeShader.SetVector(nameId, v);
                break;
            case Vector2Int v:
                this.ComputeShader.SetInts(nameId, v.x, v.y);
                break;
            case Vector3 v:
                this.ComputeShader.SetVector(nameId, v);
                break;
            case Vector3Int v:
                this.ComputeShader.SetInts(nameId, v.x, v.y, v.z);
                break;
            case Vector4 v:
                this.ComputeShader.SetVector(nameId, v);
                break;
            case Vector4[] vectors:
                this.ComputeShader.SetVectorArray(nameId, vectors);
                break;
            case Color c:
                this.ComputeShader.SetVector(nameId, new Vector4(c.r, c.g, c.b, c.a));
                break;
            case Color[] colors:
                this.ComputeShader.SetVectorArray(
                    nameId,
                    colors.Select(c => new Vector4(c.r, c.g, c.b, c.a)).ToArray());
                break;
            case Matrix4x4 m:
                this.ComputeShader.SetMatrix(nameId, m);
                break;
            case Matrix4x4[] matrices:
                this.ComputeShader.SetMatrixArray(nameId, matrices);
                break;
            case ComputeBuffer buffer:
                this.ComputeShader.SetBuffer(kernelIndex, nameId, buffer);
                break;
            case Texture t:
                // Also covers RenderTexture, which derives from Texture.
                this.ComputeShader.SetTexture(kernelIndex, nameId, t);
                break;
            // Add more cases here to support more types
            default:
                // HashSet.Add returns false for fields already reported, avoiding per-frame spam.
                if (this.reportedUnsupportedFields.Add(field))
                {
                    Debug.LogError(
                        $"Not passing unsupported type to Compute Shader: " +
                        $"{field.Name} of type {field.FieldType}");
                }

                break;
        }
    }

    /// <summary>
    /// Sets <see cref="ThreadGroupSize"/> to the number of thread groups needed to cover
    /// x × y × z threads, given the [numthreads] size declared by <paramref name="kernel"/>.
    /// Counts are rounded up, so a resolution that isn't a multiple of numthreads still gets
    /// a final, partially-filled group. The kernel should bounds-check its thread id.
    /// Does nothing if no compute shader is assigned.
    /// </summary>
    /// <param name="kernel">Kernel whose thread group size is queried.</param>
    /// <param name="x">Total threads needed along X.</param>
    /// <param name="y">Total threads needed along Y.</param>
    /// <param name="z">Total threads needed along Z.</param>
    [Button]
    public void AutoCalculateThreadGroups2D(
        string kernel,
        int x = 1,
        int y = 1,
        int z = 1)
    {
        if (this.ComputeShader == null)
        {
            return;
        }

        int kernelIndex = this.ComputeShader.FindKernel(kernel);
        this.ComputeShader.GetKernelThreadGroupSizes(kernelIndex, out uint kx, out uint ky, out uint kz);

        // Round up. Truncating division dropped the trailing partial group and produced
        // 0 groups whenever the requested size was smaller than numthreads.
        this.ThreadGroupSize = new Vector3Int(
            CeilDiv(x, (int) kx),
            CeilDiv(y, (int) ky),
            CeilDiv(z, (int) kz));
    }

    /// <summary>Integer division rounded up; intended for value ≥ 0 and divisor > 0.</summary>
    private static int CeilDiv(int value, int divisor)
    {
        return (value + divisor - 1) / divisor;
    }

    /// <summary>
    /// Inspector-button entry point: applies <see cref="Options"/> and dispatches
    /// <paramref name="kernel"/> with <see cref="ThreadGroupSize"/>.
    /// </summary>
    [Button]
    public void Dispatch(string kernel)
    {
        this.ApplyPropertiesAndDispatch(kernel);
    }

    /// <summary>
    /// Applies <see cref="Options"/> to <paramref name="kernel"/> and then dispatches it.
    /// </summary>
    /// <param name="kernel">Kernel name.</param>
    /// <param name="threadGroupSizeOverride">Thread group counts to use instead of <see cref="ThreadGroupSize"/>.</param>
    protected void ApplyPropertiesAndDispatch(string kernel, Vector3Int? threadGroupSizeOverride = null)
    {
        this.ApplyProperties(kernel);
        this.Dispatch(kernel, threadGroupSizeOverride);
    }

    /// <summary>
    /// Dispatches <paramref name="kernel"/> without touching shader properties.
    /// Does nothing if no compute shader is assigned.
    /// </summary>
    /// <param name="kernel">Kernel name.</param>
    /// <param name="threadGroupSizeOverride">Thread group counts to use instead of <see cref="ThreadGroupSize"/>.</param>
    protected void Dispatch(string kernel, Vector3Int? threadGroupSizeOverride = null)
    {
        if (this.ComputeShader == null)
        {
            return;
        }

        int kernelIndex = this.ComputeShader.FindKernel(kernel);
        Vector3Int groups = threadGroupSizeOverride ?? this.ThreadGroupSize;
        this.ComputeShader.Dispatch(kernelIndex, groups.x, groups.y, groups.z);
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment