Skip to content

Instantly share code, notes, and snippets.

@babon
Last active September 23, 2023 12:53
Show Gist options
  • Save babon/b6c0985c6f56935ee6d4251a087d526c to your computer and use it in GitHub Desktop.
Save babon/b6c0985c6f56935ee6d4251a087d526c to your computer and use it in GitHub Desktop.
Unity compute shader dispatch helper with better syntax
//Example usage:
// Runs kernel "Advect" over texWidth x texHeight x texDepth threads. The helper divides these totals by the
// kernel's thread-group sizes, rounding up. Each "name", value pair is bound in argument order before dispatching.
shader.Dispatch("Advect", texWidth, texHeight, texDepth,
"_texture", texture, // bound per-kernel via SetTexture; the static type of `texture` must be one DoT supports
"_separationPlane", Vector3.Scale(separationPlane.position, tscale), // Vector3 -> SetVector
"_advectFall", advectN == 0, // bool -> SetBool
"_advectCapsuleCenter", Vector3.Scale(playerCapsule.transform.TransformPoint(playerCapsule.center), tscale)); // Vector3 -> SetVector
// The helpers below look repetitive on purpose. They use explicit generic overloads instead of a params
// array, and DoT reinterprets values with UnsafeUtility.As instead of boxing them. Together with these
// caches, steady-state dispatches are intended to be allocation-free once each name has been seen once.
// NOTE(review): entries are never evicted. With Unity's domain reload disabled these statics presumably
// survive between play sessions — confirm that is acceptable.

/// <summary>Shader property name -> result of Shader.PropertyToID for that name.</summary>
static readonly Dictionary<string, int> stringToIntParamCache = new Dictionary<string, int>();

/// <summary>(shader, kernel name) -> result of ComputeShader.FindKernel for that pair.</summary>
static readonly Dictionary<(ComputeShader, string), int> stringToIntKernelCache = new Dictionary<(ComputeShader, string), int>();
/// <summary>
/// Resolves <paramref name="kernelName"/> to a kernel index on <paramref name="s"/>.
/// ComputeShader.FindKernel is called only the first time a (shader, name) pair is seen.
/// Later calls are served from stringToIntKernelCache.
/// </summary>
/// <param name="s">Shader that owns the kernel.</param>
/// <param name="kernelName">Kernel function name as declared in the compute shader.</param>
/// <returns>The kernel index to pass to the per-kernel ComputeShader APIs.</returns>
static int GetK(ComputeShader s, string kernelName)
{
    var cacheKey = (s, kernelName);
    if (stringToIntKernelCache.TryGetValue(cacheKey, out int kernelIndex))
    {
        return kernelIndex;
    }

    kernelIndex = s.FindKernel(kernelName);
    stringToIntKernelCache.Add(cacheKey, kernelIndex);
    return kernelIndex;
}
/// <summary>
/// Binds one named parameter on <paramref name="s"/>. The ComputeShader setter is chosen from the
/// static type <typeparamref name="T"/>, not from the runtime type of <paramref name="p"/>.
/// </summary>
/// <param name="s">Shader whose parameter is set.</param>
/// <param name="k">Kernel index. Only the per-kernel setters (SetBuffer / SetTexture) use it.</param>
/// <param name="n">Shader property name. Its Shader.PropertyToID value is cached in stringToIntParamCache.</param>
/// <param name="p">Value to bind.</param>
/// <remarks>
/// UnsafeUtility.As reinterprets <paramref name="p"/> as the matched type without boxing.
/// For an unsupported <typeparamref name="T"/>, the type is reported via Debug.LogError and nothing is set.
/// </remarks>
static void DoT<T>(ComputeShader s, int k, string n, T p)
{
    int id;
    if (!stringToIntParamCache.TryGetValue(n, out id))
    {
        id = Shader.PropertyToID(n);
        stringToIntParamCache.Add(n, id);
    }
    Type t = typeof(T);
    if (t == typeof(ComputeBuffer)) s.SetBuffer(k, id, UnsafeUtility.As<T, ComputeBuffer>(ref p));
    else if (t == typeof(RenderTexture)) s.SetTexture(k, id, UnsafeUtility.As<T, RenderTexture>(ref p));
    else if (t == typeof(Texture3D)) s.SetTexture(k, id, UnsafeUtility.As<T, Texture3D>(ref p));
    else if (t == typeof(Texture2D)) s.SetTexture(k, id, UnsafeUtility.As<T, Texture2D>(ref p));
    else if (t == typeof(Matrix4x4[])) s.SetMatrixArray(id, UnsafeUtility.As<T, Matrix4x4[]>(ref p));
    else if (t == typeof(Matrix4x4)) s.SetMatrix(id, UnsafeUtility.As<T, Matrix4x4>(ref p));
    else if (t == typeof(Vector4[])) s.SetVectorArray(id, UnsafeUtility.As<T, Vector4[]>(ref p));
    else if (t == typeof(Vector2)) s.SetVector(id, UnsafeUtility.As<T, Vector2>(ref p));
    else if (t == typeof(Vector3)) s.SetVector(id, UnsafeUtility.As<T, Vector3>(ref p));
    // NOTE(review): Vector3Int is uploaded as floats through SetVector. This presumably expects the shader
    // to declare the parameter as float3; an int3/uint3 declaration would need SetInts — TODO confirm.
    else if (t == typeof(Vector3Int)) s.SetVector(id, (Vector3)UnsafeUtility.As<T, Vector3Int>(ref p));
    else if (t == typeof(Vector4)) s.SetVector(id, UnsafeUtility.As<T, Vector4>(ref p));
    else if (t == typeof(bool)) s.SetBool(id, UnsafeUtility.As<T, bool>(ref p));
    else if (t == typeof(int)) s.SetInt(id, UnsafeUtility.As<T, int>(ref p));
    else if (t == typeof(float)) s.SetFloat(id, UnsafeUtility.As<T, float>(ref p));
    // Any other Texture-derived static type (Texture, Texture2DArray, Cubemap, CustomRenderTexture, ...).
    // It is checked last so the exact-type fast paths above are unaffected. Reinterpreting a
    // Texture-derived reference as Texture is valid because it is an upcast.
    else if (typeof(Texture).IsAssignableFrom(t)) s.SetTexture(k, id, UnsafeUtility.As<T, Texture>(ref p));
    else Debug.LogError(t + " not recognized");
}
/// <summary>
/// Looks up the kernel named <paramref name="kernelName"/> and dispatches it over the given
/// total number of threads per axis.
/// </summary>
/// <remarks>
/// The thread counts are totals, not group counts. The private int-kernel Dispatch converts them to
/// thread groups. The generic overloads that follow take 1–16 (name, value) pairs. Each pair is bound
/// with DoT in argument order after the kernel lookup and before the dispatch.
/// </remarks>
public static void Dispatch(this ComputeShader s, string kernelName, int threadsXUndivided, int threadsYUndivided, int threadsZUndivided)
{
    int kernel = GetK(s, kernelName);
    Dispatch(s, kernel, threadsXUndivided, threadsYUndivided, threadsZUndivided);
}

public static void Dispatch<T1>(this ComputeShader s, string kernelName, int threadsXUndivided, int threadsYUndivided, int threadsZUndivided, string n1, T1 p1)
{
    int kernel = GetK(s, kernelName);
    DoT(s, kernel, n1, p1);
    Dispatch(s, kernel, threadsXUndivided, threadsYUndivided, threadsZUndivided);
}

public static void Dispatch<T1, T2>(this ComputeShader s, string kernelName, int threadsXUndivided, int threadsYUndivided, int threadsZUndivided, string n1, T1 p1, string n2, T2 p2)
{
    int kernel = GetK(s, kernelName);
    DoT(s, kernel, n1, p1);
    DoT(s, kernel, n2, p2);
    Dispatch(s, kernel, threadsXUndivided, threadsYUndivided, threadsZUndivided);
}

public static void Dispatch<T1, T2, T3>(this ComputeShader s, string kernelName, int threadsXUndivided, int threadsYUndivided, int threadsZUndivided, string n1, T1 p1, string n2, T2 p2, string n3, T3 p3)
{
    int kernel = GetK(s, kernelName);
    DoT(s, kernel, n1, p1);
    DoT(s, kernel, n2, p2);
    DoT(s, kernel, n3, p3);
    Dispatch(s, kernel, threadsXUndivided, threadsYUndivided, threadsZUndivided);
}

public static void Dispatch<T1, T2, T3, T4>(this ComputeShader s, string kernelName, int threadsXUndivided, int threadsYUndivided, int threadsZUndivided, string n1, T1 p1, string n2, T2 p2, string n3, T3 p3, string n4, T4 p4)
{
    int kernel = GetK(s, kernelName);
    DoT(s, kernel, n1, p1);
    DoT(s, kernel, n2, p2);
    DoT(s, kernel, n3, p3);
    DoT(s, kernel, n4, p4);
    Dispatch(s, kernel, threadsXUndivided, threadsYUndivided, threadsZUndivided);
}

public static void Dispatch<T1, T2, T3, T4, T5>(this ComputeShader s, string kernelName, int threadsXUndivided, int threadsYUndivided, int threadsZUndivided, string n1, T1 p1, string n2, T2 p2, string n3, T3 p3, string n4, T4 p4, string n5, T5 p5)
{
    int kernel = GetK(s, kernelName);
    DoT(s, kernel, n1, p1);
    DoT(s, kernel, n2, p2);
    DoT(s, kernel, n3, p3);
    DoT(s, kernel, n4, p4);
    DoT(s, kernel, n5, p5);
    Dispatch(s, kernel, threadsXUndivided, threadsYUndivided, threadsZUndivided);
}

public static void Dispatch<T1, T2, T3, T4, T5, T6>(this ComputeShader s, string kernelName, int threadsXUndivided, int threadsYUndivided, int threadsZUndivided, string n1, T1 p1, string n2, T2 p2, string n3, T3 p3, string n4, T4 p4, string n5, T5 p5, string n6, T6 p6)
{
    int kernel = GetK(s, kernelName);
    DoT(s, kernel, n1, p1);
    DoT(s, kernel, n2, p2);
    DoT(s, kernel, n3, p3);
    DoT(s, kernel, n4, p4);
    DoT(s, kernel, n5, p5);
    DoT(s, kernel, n6, p6);
    Dispatch(s, kernel, threadsXUndivided, threadsYUndivided, threadsZUndivided);
}

public static void Dispatch<T1, T2, T3, T4, T5, T6, T7>(this ComputeShader s, string kernelName, int threadsXUndivided, int threadsYUndivided, int threadsZUndivided, string n1, T1 p1, string n2, T2 p2, string n3, T3 p3, string n4, T4 p4, string n5, T5 p5, string n6, T6 p6, string n7, T7 p7)
{
    int kernel = GetK(s, kernelName);
    DoT(s, kernel, n1, p1);
    DoT(s, kernel, n2, p2);
    DoT(s, kernel, n3, p3);
    DoT(s, kernel, n4, p4);
    DoT(s, kernel, n5, p5);
    DoT(s, kernel, n6, p6);
    DoT(s, kernel, n7, p7);
    Dispatch(s, kernel, threadsXUndivided, threadsYUndivided, threadsZUndivided);
}

public static void Dispatch<T1, T2, T3, T4, T5, T6, T7, T8>(this ComputeShader s, string kernelName, int threadsXUndivided, int threadsYUndivided, int threadsZUndivided, string n1, T1 p1, string n2, T2 p2, string n3, T3 p3, string n4, T4 p4, string n5, T5 p5, string n6, T6 p6, string n7, T7 p7, string n8, T8 p8)
{
    int kernel = GetK(s, kernelName);
    DoT(s, kernel, n1, p1);
    DoT(s, kernel, n2, p2);
    DoT(s, kernel, n3, p3);
    DoT(s, kernel, n4, p4);
    DoT(s, kernel, n5, p5);
    DoT(s, kernel, n6, p6);
    DoT(s, kernel, n7, p7);
    DoT(s, kernel, n8, p8);
    Dispatch(s, kernel, threadsXUndivided, threadsYUndivided, threadsZUndivided);
}

public static void Dispatch<T1, T2, T3, T4, T5, T6, T7, T8, T9>(this ComputeShader s, string kernelName, int threadsXUndivided, int threadsYUndivided, int threadsZUndivided, string n1, T1 p1, string n2, T2 p2, string n3, T3 p3, string n4, T4 p4, string n5, T5 p5, string n6, T6 p6, string n7, T7 p7, string n8, T8 p8, string n9, T9 p9)
{
    int kernel = GetK(s, kernelName);
    DoT(s, kernel, n1, p1);
    DoT(s, kernel, n2, p2);
    DoT(s, kernel, n3, p3);
    DoT(s, kernel, n4, p4);
    DoT(s, kernel, n5, p5);
    DoT(s, kernel, n6, p6);
    DoT(s, kernel, n7, p7);
    DoT(s, kernel, n8, p8);
    DoT(s, kernel, n9, p9);
    Dispatch(s, kernel, threadsXUndivided, threadsYUndivided, threadsZUndivided);
}

public static void Dispatch<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10>(this ComputeShader s, string kernelName, int threadsXUndivided, int threadsYUndivided, int threadsZUndivided, string n1, T1 p1, string n2, T2 p2, string n3, T3 p3, string n4, T4 p4, string n5, T5 p5, string n6, T6 p6, string n7, T7 p7, string n8, T8 p8, string n9, T9 p9, string n10, T10 p10)
{
    int kernel = GetK(s, kernelName);
    DoT(s, kernel, n1, p1);
    DoT(s, kernel, n2, p2);
    DoT(s, kernel, n3, p3);
    DoT(s, kernel, n4, p4);
    DoT(s, kernel, n5, p5);
    DoT(s, kernel, n6, p6);
    DoT(s, kernel, n7, p7);
    DoT(s, kernel, n8, p8);
    DoT(s, kernel, n9, p9);
    DoT(s, kernel, n10, p10);
    Dispatch(s, kernel, threadsXUndivided, threadsYUndivided, threadsZUndivided);
}

public static void Dispatch<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11>(this ComputeShader s, string kernelName, int threadsXUndivided, int threadsYUndivided, int threadsZUndivided, string n1, T1 p1, string n2, T2 p2, string n3, T3 p3, string n4, T4 p4, string n5, T5 p5, string n6, T6 p6, string n7, T7 p7, string n8, T8 p8, string n9, T9 p9, string n10, T10 p10, string n11, T11 p11)
{
    int kernel = GetK(s, kernelName);
    DoT(s, kernel, n1, p1);
    DoT(s, kernel, n2, p2);
    DoT(s, kernel, n3, p3);
    DoT(s, kernel, n4, p4);
    DoT(s, kernel, n5, p5);
    DoT(s, kernel, n6, p6);
    DoT(s, kernel, n7, p7);
    DoT(s, kernel, n8, p8);
    DoT(s, kernel, n9, p9);
    DoT(s, kernel, n10, p10);
    DoT(s, kernel, n11, p11);
    Dispatch(s, kernel, threadsXUndivided, threadsYUndivided, threadsZUndivided);
}

public static void Dispatch<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12>(this ComputeShader s, string kernelName, int threadsXUndivided, int threadsYUndivided, int threadsZUndivided, string n1, T1 p1, string n2, T2 p2, string n3, T3 p3, string n4, T4 p4, string n5, T5 p5, string n6, T6 p6, string n7, T7 p7, string n8, T8 p8, string n9, T9 p9, string n10, T10 p10, string n11, T11 p11, string n12, T12 p12)
{
    int kernel = GetK(s, kernelName);
    DoT(s, kernel, n1, p1);
    DoT(s, kernel, n2, p2);
    DoT(s, kernel, n3, p3);
    DoT(s, kernel, n4, p4);
    DoT(s, kernel, n5, p5);
    DoT(s, kernel, n6, p6);
    DoT(s, kernel, n7, p7);
    DoT(s, kernel, n8, p8);
    DoT(s, kernel, n9, p9);
    DoT(s, kernel, n10, p10);
    DoT(s, kernel, n11, p11);
    DoT(s, kernel, n12, p12);
    Dispatch(s, kernel, threadsXUndivided, threadsYUndivided, threadsZUndivided);
}

public static void Dispatch<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13>(this ComputeShader s, string kernelName, int threadsXUndivided, int threadsYUndivided, int threadsZUndivided, string n1, T1 p1, string n2, T2 p2, string n3, T3 p3, string n4, T4 p4, string n5, T5 p5, string n6, T6 p6, string n7, T7 p7, string n8, T8 p8, string n9, T9 p9, string n10, T10 p10, string n11, T11 p11, string n12, T12 p12, string n13, T13 p13)
{
    int kernel = GetK(s, kernelName);
    DoT(s, kernel, n1, p1);
    DoT(s, kernel, n2, p2);
    DoT(s, kernel, n3, p3);
    DoT(s, kernel, n4, p4);
    DoT(s, kernel, n5, p5);
    DoT(s, kernel, n6, p6);
    DoT(s, kernel, n7, p7);
    DoT(s, kernel, n8, p8);
    DoT(s, kernel, n9, p9);
    DoT(s, kernel, n10, p10);
    DoT(s, kernel, n11, p11);
    DoT(s, kernel, n12, p12);
    DoT(s, kernel, n13, p13);
    Dispatch(s, kernel, threadsXUndivided, threadsYUndivided, threadsZUndivided);
}

public static void Dispatch<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14>(this ComputeShader s, string kernelName, int threadsXUndivided, int threadsYUndivided, int threadsZUndivided, string n1, T1 p1, string n2, T2 p2, string n3, T3 p3, string n4, T4 p4, string n5, T5 p5, string n6, T6 p6, string n7, T7 p7, string n8, T8 p8, string n9, T9 p9, string n10, T10 p10, string n11, T11 p11, string n12, T12 p12, string n13, T13 p13, string n14, T14 p14)
{
    int kernel = GetK(s, kernelName);
    DoT(s, kernel, n1, p1);
    DoT(s, kernel, n2, p2);
    DoT(s, kernel, n3, p3);
    DoT(s, kernel, n4, p4);
    DoT(s, kernel, n5, p5);
    DoT(s, kernel, n6, p6);
    DoT(s, kernel, n7, p7);
    DoT(s, kernel, n8, p8);
    DoT(s, kernel, n9, p9);
    DoT(s, kernel, n10, p10);
    DoT(s, kernel, n11, p11);
    DoT(s, kernel, n12, p12);
    DoT(s, kernel, n13, p13);
    DoT(s, kernel, n14, p14);
    Dispatch(s, kernel, threadsXUndivided, threadsYUndivided, threadsZUndivided);
}

public static void Dispatch<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15>(this ComputeShader s, string kernelName, int threadsXUndivided, int threadsYUndivided, int threadsZUndivided, string n1, T1 p1, string n2, T2 p2, string n3, T3 p3, string n4, T4 p4, string n5, T5 p5, string n6, T6 p6, string n7, T7 p7, string n8, T8 p8, string n9, T9 p9, string n10, T10 p10, string n11, T11 p11, string n12, T12 p12, string n13, T13 p13, string n14, T14 p14, string n15, T15 p15)
{
    int kernel = GetK(s, kernelName);
    DoT(s, kernel, n1, p1);
    DoT(s, kernel, n2, p2);
    DoT(s, kernel, n3, p3);
    DoT(s, kernel, n4, p4);
    DoT(s, kernel, n5, p5);
    DoT(s, kernel, n6, p6);
    DoT(s, kernel, n7, p7);
    DoT(s, kernel, n8, p8);
    DoT(s, kernel, n9, p9);
    DoT(s, kernel, n10, p10);
    DoT(s, kernel, n11, p11);
    DoT(s, kernel, n12, p12);
    DoT(s, kernel, n13, p13);
    DoT(s, kernel, n14, p14);
    DoT(s, kernel, n15, p15);
    Dispatch(s, kernel, threadsXUndivided, threadsYUndivided, threadsZUndivided);
}

public static void Dispatch<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16>(this ComputeShader s, string kernelName, int threadsXUndivided, int threadsYUndivided, int threadsZUndivided, string n1, T1 p1, string n2, T2 p2, string n3, T3 p3, string n4, T4 p4, string n5, T5 p5, string n6, T6 p6, string n7, T7 p7, string n8, T8 p8, string n9, T9 p9, string n10, T10 p10, string n11, T11 p11, string n12, T12 p12, string n13, T13 p13, string n14, T14 p14, string n15, T15 p15, string n16, T16 p16)
{
    int kernel = GetK(s, kernelName);
    DoT(s, kernel, n1, p1);
    DoT(s, kernel, n2, p2);
    DoT(s, kernel, n3, p3);
    DoT(s, kernel, n4, p4);
    DoT(s, kernel, n5, p5);
    DoT(s, kernel, n6, p6);
    DoT(s, kernel, n7, p7);
    DoT(s, kernel, n8, p8);
    DoT(s, kernel, n9, p9);
    DoT(s, kernel, n10, p10);
    DoT(s, kernel, n11, p11);
    DoT(s, kernel, n12, p12);
    DoT(s, kernel, n13, p13);
    DoT(s, kernel, n14, p14);
    DoT(s, kernel, n15, p15);
    DoT(s, kernel, n16, p16);
    Dispatch(s, kernel, threadsXUndivided, threadsYUndivided, threadsZUndivided);
}
/// <summary>
/// Dispatches kernel <paramref name="k"/> with enough thread groups to cover the requested total
/// thread count on each axis. Each axis is rounded up to a whole number of groups.
/// </summary>
/// <param name="s">Shader to dispatch.</param>
/// <param name="k">Kernel index, as returned by GetK.</param>
/// <param name="threadsXUndivided">Total threads wanted along X (not thread groups).</param>
/// <param name="threadsYUndivided">Total threads wanted along Y (not thread groups).</param>
/// <param name="threadsZUndivided">Total threads wanted along Z (not thread groups).</param>
/// <remarks>
/// Group counts use exact integer ceiling division instead of Mathf.CeilToInt(n / (float)size).
/// Float division cannot represent thread counts above 2^24 exactly and can round down to one
/// group too few. For example, 20,480,001 threads with a group size of 1024 yields 20000 groups
/// instead of 20001. When a count is not a multiple of the group size, the last group runs extra
/// threads, so the kernel must bounds-check its thread IDs.
/// </remarks>
static void Dispatch(this ComputeShader s, int k, int threadsXUndivided, int threadsYUndivided, int threadsZUndivided)
{
    s.GetKernelThreadGroupSizes(k, out uint xSize, out uint ySize, out uint zSize);
    s.Dispatch(k, GroupsToCover(threadsXUndivided, xSize), GroupsToCover(threadsYUndivided, ySize), GroupsToCover(threadsZUndivided, zSize));
}

/// <summary>
/// Returns ceil(<paramref name="threads"/> / <paramref name="groupSize"/>) computed exactly in integer
/// arithmetic. This is the value the float-based Mathf.CeilToInt form only approximates. Zero maps to 0,
/// and negative inputs map to their (non-positive) ceiling, just as before.
/// </summary>
static int GroupsToCover(int threads, uint groupSize)
{
    // int op uint promotes both operands to long, so neither expression can overflow.
    long groups = threads / groupSize;
    // C# division truncates toward zero, which is already the ceiling for negative inputs.
    // Only a positive remainder needs rounding up.
    if (threads % groupSize > 0)
    {
        groups++;
    }
    return (int)groups;
}
public static void DispatchIndirect(this ComputeShader s, string kernelName, ComputeBuffer groupNumsIn012) { int k = GetK(s, kernelName); DispatchIndirect(s, k, groupNumsIn012); }
public static void DispatchIndirect<T1>(this ComputeShader s, string kernelName, ComputeBuffer groupNumsIn012, string n1, T1 p1) { int k = GetK(s, kernelName); DoT(s, k, n1, p1); DispatchIndirect(s, k, groupNumsIn012); }
public static void DispatchIndirect<T1, T2>(this ComputeShader s, string kernelName, ComputeBuffer groupNumsIn012, string n1, T1 p1, string n2, T2 p2) { int k = GetK(s, kernelName); DoT(s, k, n1, p1); DoT(s, k, n2, p2); DispatchIndirect(s, k, groupNumsIn012); }
public static void DispatchIndirect<T1, T2, T3>(this ComputeShader s, string kernelName, ComputeBuffer groupNumsIn012, string n1, T1 p1, string n2, T2 p2, string n3, T3 p3) { int k = GetK(s, kernelName); DoT(s, k, n1, p1); DoT(s, k, n2, p2); DoT(s, k, n3, p3); DispatchIndirect(s, k, groupNumsIn012); }
public static void DispatchIndirect<T1, T2, T3, T4>(this ComputeShader s, string kernelName, ComputeBuffer groupNumsIn012, string n1, T1 p1, string n2, T2 p2, string n3, T3 p3, string n4, T4 p4) { int k = GetK(s, kernelName); DoT(s, k, n1, p1); DoT(s, k, n2, p2); DoT(s, k, n3, p3); DoT(s, k, n4, p4); DispatchIndirect(s, k, groupNumsIn012); }
public static void DispatchIndirect<T1, T2, T3, T4, T5>(this ComputeShader s, string kernelName, ComputeBuffer groupNumsIn012, string n1, T1 p1, string n2, T2 p2, string n3, T3 p3, string n4, T4 p4, string n5, T5 p5) { int k = GetK(s, kernelName); DoT(s, k, n1, p1); DoT(s, k, n2, p2); DoT(s, k, n3, p3); DoT(s, k, n4, p4); DoT(s, k, n5, p5); DispatchIndirect(s, k, groupNumsIn012); }
public static void DispatchIndirect<T1, T2, T3, T4, T5, T6>(this ComputeShader s, string kernelName, ComputeBuffer groupNumsIn012, string n1, T1 p1, string n2, T2 p2, string n3, T3 p3, string n4, T4 p4, string n5, T5 p5, string n6, T6 p6) { int k = GetK(s, kernelName); DoT(s, k, n1, p1); DoT(s, k, n2, p2); DoT(s, k, n3, p3); DoT(s, k, n4, p4); DoT(s, k, n5, p5); DoT(s, k, n6, p6); DispatchIndirect(s, k, groupNumsIn012); }
public static void DispatchIndirect<T1, T2, T3, T4, T5, T6, T7>(this ComputeShader s, string kernelName, ComputeBuffer groupNumsIn012, string n1, T1 p1, string n2, T2 p2, string n3, T3 p3, string n4, T4 p4, string n5, T5 p5, string n6, T6 p6, string n7, T7 p7) { int k = GetK(s, kernelName); DoT(s, k, n1, p1); DoT(s, k, n2, p2); DoT(s, k, n3, p3); DoT(s, k, n4, p4); DoT(s, k, n5, p5); DoT(s, k, n6, p6); DoT(s, k, n7, p7); DispatchIndirect(s, k, groupNumsIn012); }
/// <summary>
/// Dispatches kernel <paramref name="k"/> with its group counts taken from
/// <paramref name="groupNumsIn012"/>. The buffer is handed unchanged to the ComputeShader
/// instance method DispatchIndirect, with no division by the kernel's thread-group size.
/// </summary>
static void DispatchIndirect(this ComputeShader s, int k, ComputeBuffer groupNumsIn012) => s.DispatchIndirect(k, groupNumsIn012);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment