Skip to content

Instantly share code, notes, and snippets.

@yumayanagisawa
Last active November 20, 2021 23:33
Show Gist options
  • Star 11 You must be signed in to star a gist
  • Fork 4 You must be signed in to fork a gist
  • Save yumayanagisawa/742bf24d5edf1e73b971e14a2553ad4e to your computer and use it in GitHub Desktop.
Save yumayanagisawa/742bf24d5edf1e73b971e14a2553ad4e to your computer and use it in GitHub Desktop.
Unity | Compute Shader Particle System
// Declares CSParticle as the kernel entry point (looked up from C# with FindKernel("CSParticle")).
#pragma kernel CSParticle
// Particle's data.
// The layout (float3 + float3 + float = 7 floats = 28 bytes) must match the C#
// Particle struct and the stride passed to the ComputeBuffer on the CPU side.
struct Particle
{
float3 position;
float3 velocity;
// Remaining lifetime in seconds; CSParticle subtracts deltaTime and respawns the particle once it goes negative.
float life;
};
// Particle's data, shared with the shader (bound under the name "particleBuffer"
// both to this kernel and to the drawing material).
RWStructuredBuffer<Particle> particleBuffer;
// Variables set from the CPU
// Frame time in seconds (Time.deltaTime), set every Update.
float deltaTime;
// World-space x/y of the attraction target (the point under the mouse cursor);
// the kernel places the target at z = 3.
float2 mousePosition;
// Cheap pseudo-random value in [0, 1) derived from a 2D input, using the
// classic dot / sin / frac construction. Not referenced by CSParticle.
float nrand(float2 uv)
{
    float seed = dot(uv, float2(12.9898, 78.233));
    float wave = sin(seed);
    return frac(wave * 43758.5453);
}
// Per-thread state for rand_xorshift().
// Declared static so it is a private, writable per-thread variable. A plain
// (non-static) HLSL global is an implicit uniform constant: assigning to it is
// a compile error outside compatibility mode and is not portable to other
// graphics backends.
static uint rng_state;

// Xorshift algorithm from George Marsaglia's paper.
// Advances rng_state and returns the new 32-bit value. A state of 0 maps to 0
// forever, so callers should seed rng_state with a non-zero (ideally hashed)
// value before the first call.
uint rand_xorshift()
{
    rng_state ^= (rng_state << 13);
    rng_state ^= (rng_state >> 17);
    rng_state ^= (rng_state << 5);
    return rng_state;
}
// Thomas Wang's 32-bit integer hash, as given in
// http://www.reedbeta.com/blog/quick-and-easy-gpu-random-numbers-in-d3d11/
// CSParticle runs the thread index through it before seeding xorshift. With
// raw sequential seeds, the first xorshift outputs of neighbouring threads are
// strongly correlated, which shows up as visible structure in the respawn
// positions.
uint wang_hash(uint seed)
{
    seed = (seed ^ 61) ^ (seed >> 16);
    seed *= 9;
    seed = seed ^ (seed >> 4);
    seed *= 0x27d4eb2d;
    seed = seed ^ (seed >> 15);
    return seed;
}

// One thread per particle. Each dispatch:
//   1. ages the particle by deltaTime,
//   2. accelerates it toward the target (mousePosition.x, mousePosition.y, 3)
//      and integrates its position,
//   3. respawns it within radius 0.8 of the target once its life drops below zero.
[numthreads(256, 1, 1)]
void CSParticle(uint3 id : SV_DispatchThreadID)
{
    // The CPU dispatches ceil(particleCount / 256) groups, so the last group
    // may contain threads past the end of the buffer. D3D11 defines
    // out-of-bounds UAV reads as 0 and drops such writes, but not every API
    // Unity targets guarantees that, so skip those threads explicitly.
    uint numParticles, particleStride;
    particleBuffer.GetDimensions(numParticles, particleStride);
    if (id.x >= numParticles)
        return;

    // Read once, update a local copy, write back once at the end.
    Particle p = particleBuffer[id.x];

    // subtract the life based on deltaTime
    p.life -= deltaTime;

    // Accelerate toward the target. normalize() of a zero vector yields NaN,
    // which would poison velocity and position until the next respawn, so a
    // particle sitting exactly on the target gets no push this frame.
    // NOTE(review): the velocity increment is not scaled by deltaTime, so the
    // effective acceleration depends on the frame rate.
    float3 delta = float3(mousePosition.xy, 3) - p.position;
    float3 dir = float3(0, 0, 0);
    if (dot(delta, delta) > 0)
        dir = normalize(delta);
    p.velocity += dir;
    p.position += p.velocity * deltaTime;

    if (p.life < 0)
    {
        // Seed the per-thread RNG from the hashed particle index.
        rng_state = wang_hash(id.x);

        // Random direction from three values in [-0.5, 0.5].
        float f0 = float(rand_xorshift()) * (1.0 / 4294967296.0) - 0.5;
        float f1 = float(rand_xorshift()) * (1.0 / 4294967296.0) - 0.5;
        float f2 = float(rand_xorshift()) * (1.0 / 4294967296.0) - 0.5;

        // Offset of length 0.8 * u, with u a random value in [0, 1].
        float3 normalF3 = normalize(float3(f0, f1, f2)) * 0.8f;
        normalF3 *= float(rand_xorshift()) * (1.0 / 4294967296.0);
        p.position = float3(mousePosition.xy, 3.0) + normalF3;

        // reset the life of this particle
        p.life = 4;
        p.velocity = float3(0, 0, 0);
    }

    particleBuffer[id.x] = p;
}
Shader "Custom/Particle" {
    // Unlit, additively blended point particles. Particle data comes from a
    // StructuredBuffer named "particleBuffer" bound from C# (material.SetBuffer)
    // and indexed by instance id in the vertex shader.
    SubShader {
        // RenderType and LOD are SubShader-level settings. Blending below is
        // additive, so the shader is tagged transparent rather than opaque.
        Tags { "RenderType" = "Transparent" }
        LOD 200

        Pass {
            // Additive blending: out.rgb = src.rgb * src.a + dst.rgb.
            Blend SrcAlpha One
            // Additive accumulation is only order-independent if particles do
            // not reject each other through the depth buffer.
            ZWrite Off

            CGPROGRAM
            #pragma vertex vert
            #pragma fragment frag
            // Structured buffers need a compute-capable target (4.5+);
            // 5.0 is used here.
            #pragma target 5.0
            #include "UnityCG.cginc"

            // Must match the layout of Particle in the compute shader and in
            // the C# struct (3 + 3 + 1 floats).
            struct Particle {
                float3 position;
                float3 velocity;
                float life;
            };

            struct PS_INPUT {
                float4 position : SV_POSITION;
                float4 color : COLOR;
            };

            // particles' data (read-only here; updated by the compute shader)
            StructuredBuffer<Particle> particleBuffer;

            // Each instance is one particle drawn as a single point, so only
            // instance_id is needed to locate the particle.
            PS_INPUT vert(uint vertex_id : SV_VertexID, uint instance_id : SV_InstanceID)
            {
                PS_INPUT o = (PS_INPUT)0;
                Particle p = particleBuffer[instance_id];

                // Color: goes from cyan (life 4) toward magenta while alpha
                // fades to 0 as life runs out. Clamped because the initial
                // lives written by the CPU go up to 6, which would otherwise
                // drive red negative and alpha above 1.
                float lerpVal = saturate(p.life * 0.25f);
                o.color = fixed4(1.0f - lerpVal + 0.1, lerpVal + 0.1, 1.0f, lerpVal);

                // Position
                o.position = UnityObjectToClipPos(float4(p.position, 1.0f));
                return o;
            }

            float4 frag(PS_INPUT i) : SV_Target
            {
                return i.color;
            }
            ENDCG
        }
    }
    FallBack Off
}
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
/// <summary>
/// GPU particle system. It creates <see cref="particleCount"/> particles in a
/// ComputeBuffer, advances them every frame with the "CSParticle" kernel of
/// <see cref="computeShader"/>, and draws them as points with
/// <see cref="material"/>. Particles are attracted toward the world-space
/// position under the mouse cursor.
/// </summary>
public class RunCompute : MonoBehaviour {
    /// <summary>World-space x/y under the mouse cursor; refreshed in OnGUI, sent to the kernel in Update.</summary>
    private Vector2 cursorPos;

    /// <summary>
    /// Scratch array for ComputeShader.SetFloats. It is reused every frame instead
    /// of allocating a new float[2] in each Update.
    /// </summary>
    private readonly float[] mousePosition2D = new float[2];

    /// <summary>
    /// CPU-side mirror of the shader's Particle struct. Field order and types
    /// must match the HLSL declaration, because the array is uploaded byte-for-byte.
    /// </summary>
    struct Particle
    {
        public Vector3 position;
        public Vector3 velocity;
        public float life;
    }

    /// <summary>
    /// Size in bytes of one <see cref="Particle"/>, used as the ComputeBuffer stride.
    /// It is position (3 floats) + velocity (3 floats) + life (1 float) = 7 floats = 28 bytes.
    /// </summary>
    private const int SIZE_PARTICLE = (3 + 3 + 1) * sizeof(float);

    /// <summary>
    /// Number of Particle created in the system.
    /// </summary>
    private int particleCount = 1000000;

    /// <summary>
    /// Material used to draw the particles; its shader reads "particleBuffer".
    /// </summary>
    public Material material;

    /// <summary>
    /// Compute shader containing the "CSParticle" update kernel.
    /// </summary>
    public ComputeShader computeShader;

    /// <summary>
    /// Index of the "CSParticle" kernel, resolved in InitComputeShader.
    /// </summary>
    private int mComputeShaderKernelID;

    /// <summary>
    /// GPU buffer holding the particles. It is null before initialisation and after release.
    /// </summary>
    ComputeBuffer particleBuffer;

    /// <summary>
    /// Threads per thread group. Must equal the [numthreads(256, 1, 1)]
    /// declared on the CSParticle kernel.
    /// </summary>
    private const int WARP_SIZE = 256;

    /// <summary>
    /// Number of thread groups dispatched per frame: ceil(particleCount / WARP_SIZE).
    /// </summary>
    private int mWarpCount;

    void Start () {
        InitComputeShader();
    }

    /// <summary>
    /// Fills the particle buffer with randomised initial particles and binds it
    /// to both the compute shader and the material. Disables this component
    /// (after logging an error) if either shader asset is not assigned.
    /// </summary>
    void InitComputeShader()
    {
        if (computeShader == null || material == null)
        {
            Debug.LogError("RunCompute: both 'computeShader' and 'material' must be assigned.", this);
            enabled = false;
            return;
        }

        mWarpCount = Mathf.CeilToInt((float)particleCount / WARP_SIZE);

        // initialize the particles
        Particle[] particleArray = new Particle[particleCount];
        for (int i = 0; i < particleCount; i++)
        {
            // Random offset within radius 0.5 of (0, 0, 3). The direction comes
            // from a random point in the [-1, 1] cube and the length is random.
            float x = Random.value * 2 - 1.0f;
            float y = Random.value * 2 - 1.0f;
            float z = Random.value * 2 - 1.0f;
            Vector3 xyz = new Vector3(x, y, z);
            xyz.Normalize();
            xyz *= Random.value;
            xyz *= 0.5f;

            particleArray[i].position = new Vector3(xyz.x, xyz.y, xyz.z + 3);
            particleArray[i].velocity = Vector3.zero;
            // Initial life in [1, 6] seconds, randomised so expiries are spread out.
            particleArray[i].life = Random.value * 5.0f + 1.0f;
        }

        // create compute buffer
        particleBuffer = new ComputeBuffer(particleCount, SIZE_PARTICLE);
        particleBuffer.SetData(particleArray);

        // find the id of the kernel
        mComputeShaderKernelID = computeShader.FindKernel("CSParticle");

        // The kernel (read/write) and the draw shader (read-only) share one buffer.
        computeShader.SetBuffer(mComputeShaderKernelID, "particleBuffer", particleBuffer);
        material.SetBuffer("particleBuffer", particleBuffer);
    }

    /// <summary>
    /// Draws every particle as one point. The vertex shader reads the particle
    /// selected by SV_InstanceID.
    /// </summary>
    void OnRenderObject()
    {
        if (particleBuffer == null)
        {
            return;
        }

        material.SetPass(0);
        // NOTE(review): this immediate-mode overload is marked obsolete in newer
        // Unity versions in favour of Graphics.DrawProceduralNow.
        Graphics.DrawProcedural(MeshTopology.Points, 1, particleCount);
    }

    /// <summary>
    /// Releases the GPU buffer and clears the reference so later callbacks skip it.
    /// </summary>
    void OnDestroy()
    {
        if (particleBuffer != null)
        {
            particleBuffer.Release();
            particleBuffer = null;
        }
    }

    /// <summary>
    /// Uploads the per-frame parameters and runs one simulation step on the GPU.
    /// </summary>
    void Update () {
        if (particleBuffer == null)
        {
            return;
        }

        mousePosition2D[0] = cursorPos.x;
        mousePosition2D[1] = cursorPos.y;

        // Send datas to the compute shader
        computeShader.SetFloat("deltaTime", Time.deltaTime);
        computeShader.SetFloats("mousePosition", mousePosition2D);

        // ceil(particleCount / WARP_SIZE) groups; the last group may overshoot particleCount.
        computeShader.Dispatch(mComputeShaderKernelID, mWarpCount, 1, 1);
    }

    /// <summary>
    /// Converts the current mouse position to world space and stores its x/y in
    /// <see cref="cursorPos"/>. Keeps the previous value when there is no main camera.
    /// </summary>
    void OnGUI()
    {
        Camera c = Camera.main;
        if (c == null)
        {
            return;
        }

        Event e = Event.current;

        // Event coordinates have y pointing down; screen coordinates have y pointing up.
        Vector2 mousePos = new Vector2(e.mousePosition.x, c.pixelHeight - e.mousePosition.y);

        // NOTE(review): the depth nearClipPlane + 14 (measured from the camera) is
        // presumably chosen so the point lands on the z = 3 plane the kernel
        // attracts to; that only holds for a specific camera placement. Confirm.
        Vector3 p = c.ScreenToWorldPoint(new Vector3(mousePos.x, mousePos.y, c.nearClipPlane + 14));
        cursorPos.x = p.x;
        cursorPos.y = p.y;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment