Skip to content

Instantly share code, notes, and snippets.

@AtheMathmo
Last active August 14, 2020 17:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AtheMathmo/0fac0e80833691701f44f6378cbebd29 to your computer and use it in GitHub Desktop.
Save AtheMathmo/0fac0e80833691701f44f6378cbebd29 to your computer and use it in GitHub Desktop.
Code for procedurally generating lightning bolts in Unity
using UnityEngine;
/// <summary>
/// RRT grower tuned for lightning: samples come from an ellipsoid stretched between
/// start and goal, and the tree may only grow downwards (towards smaller y).
/// </summary>
public class LightningRRTGrower : RRTGrower
{
    // Centre of the sampling ellipsoid: the midpoint of start and goal.
    Vector3 ellipsoidCenter;
    // Semi-axis along the start->goal direction: half the straight-line start-goal distance.
    float halfPathLength;
    // Semi-axis (in both perpendicular directions) across the start->goal direction.
    float sampleRadius = 3f;

    /// <summary>Semi-axis of the sampling ellipsoid perpendicular to the start-goal axis.</summary>
    public float SampleRadius
    {
        get { return sampleRadius; }
        set { sampleRadius = value; }
    }

    public LightningRRTGrower(int maxIters, Vector3 start, Vector3 goal) : base(maxIters, start, goal)
    {
        Vector3 span = goal - start;
        halfPathLength = span.magnitude / 2;
        ellipsoidCenter = (goal + start) / 2;
    }

    /// <summary>
    /// Returns a random point inside an ellipsoid that spans the space between start and goal.
    /// The ellipsoid's centre and length are fixed at construction; its orientation follows
    /// the direction from <c>start</c> to the <paramref name="goal"/> argument.
    /// The sample is not biased towards the goal; a more efficient, goal-biased sampler could be designed.
    /// </summary>
    protected override Vector3 SamplePoint(Vector3 goal)
    {
        Vector3 axis = (goal - start).normalized;
        // Rotate the local +z axis onto the start->goal axis.
        Quaternion orientation = Quaternion.FromToRotation(Vector3.forward, axis);
        Vector3 unit = Random.insideUnitSphere;
        // Stretch the unit sphere: radius across the axis, half path length along it.
        Vector3 stretched = new Vector3(unit.x * sampleRadius, unit.y * sampleRadius, unit.z * halfPathLength);
        return ellipsoidCenter + orientation * stretched;
    }

    /// <summary>
    /// Accepts <paramref name="target"/> only when it lies strictly below <paramref name="closestPos"/>
    /// (smaller y). The new node is then placed exactly at the target; otherwise
    /// <paramref name="extension"/> is null and false is returned.
    /// </summary>
    protected override bool TryToExtend(Vector3 closestPos, Vector3 target, out VectorNode extension)
    {
        bool isBelow = closestPos.y > target.y;
        extension = isBelow ? new VectorNode(target) : null;
        return isBelow;
    }
}
using System.Collections;
using System.Collections.Generic;
/// <summary>
/// Static pool of reusable <see cref="List{T}"/> instances, one pool per element type.
/// Lists are emptied when they are returned. No synchronization is performed.
/// </summary>
public static class ListPool<T>
{
    static Stack<List<T>> pool = new Stack<List<T>>();

    /// <summary>Takes the most recently returned list, or allocates a new one when the pool is empty.</summary>
    public static List<T> Get() => pool.Count > 0 ? pool.Pop() : new List<T>();

    /// <summary>Empties <paramref name="list"/> and stores it for reuse by a later <see cref="Get"/>.</summary>
    public static void Add(List<T> list)
    {
        list.Clear();
        pool.Push(list);
    }
}
using UnityEngine;
/// <summary>
/// Base class for growing rapidly-exploring random trees (RRTs) from <c>start</c> towards <c>goal</c>.
/// Subclasses decide where to sample and whether/how the tree may be extended towards a sample.
/// </summary>
public abstract class RRTGrower
{
    protected Vector3 start;
    protected Vector3 goal;
    protected int maxIters;
    protected float goalTolerance = 1e-1f;

    /// <summary>A new node closer than this distance to the goal counts as reaching it.</summary>
    public float GoalTolerance { get => goalTolerance; set => goalTolerance = value; }

    /// <param name="maxIters">Maximum number of sample/extend iterations performed by <see cref="GrowTree"/>.</param>
    /// <param name="start">Position of the root node.</param>
    /// <param name="goal">Position the tree grows towards.</param>
    public RRTGrower(int maxIters, Vector3 start, Vector3 goal)
    {
        this.maxIters = maxIters;
        this.start = start;
        this.goal = goal;
    }

    /// <summary>Returns a candidate point for the tree to grow towards.</summary>
    protected abstract Vector3 SamplePoint(Vector3 goal);

    /// <summary>
    /// Tries to create a node growing from <paramref name="closestPos"/> towards <paramref name="target"/>.
    /// Returns false when the extension is rejected.
    /// </summary>
    protected abstract bool TryToExtend(Vector3 closestPos, Vector3 target, out VectorNode extension);

    /// <summary>
    /// Grows a new tree rooted at <c>start</c>. When a node lands within <see cref="GoalTolerance"/>
    /// of the goal, a goal node is attached to it and growth stops. If the goal is never reached,
    /// one final attempt is made to extend the closest node directly to the goal.
    /// </summary>
    public RRT GrowTree()
    {
        RRT tree = new RRT();
        tree.AddNode(new VectorNode(start));
        bool reachedGoal = false;
        float toleranceSqr = goalTolerance * goalTolerance;

        // Grow until a node reaches the goal OR maxIters iterations are used up,
        // whichever happens first.
        for (int iter = 0; iter < maxIters && !reachedGoal; iter++)
        {
            // Pick a random target point.
            Vector3 sampleTarget = SamplePoint(goal);
            // Grow from the existing node nearest to that target.
            int closest = tree.GetClosestNodeIdx(sampleTarget);
            if (!TryToExtend(tree.Nodes[closest].Pos, sampleTarget, out VectorNode extension))
            {
                continue;
            }
            tree.AddChildNode(closest, extension);

            // AddChildNode appends, so the extension is now the last node in the tree.
            if ((extension.Pos - goal).sqrMagnitude < toleranceSqr)
            {
                tree.AddChildNode(tree.NodeCount - 1, new VectorNode(goal));
                reachedGoal = true;
            }
        }

        // Try to force a final connection to the goal.
        // This step is not part of the standard RRT algorithm.
        if (!reachedGoal) {
            int closest = tree.GetClosestNodeIdx(goal);
            if (TryToExtend(tree.Nodes[closest].Pos, goal, out VectorNode extension))
            {
                tree.AddChildNode(closest, extension);
            }
        }
        return tree;
    }
}
/// <summary>
/// A rapidly-exploring random tree: a <see cref="TreeGraph{T}"/> of positioned nodes
/// with a nearest-node query.
/// </summary>
public class RRT : TreeGraph<VectorNode>
{
    /// <summary>
    /// Returns the index of the node nearest to <paramref name="target"/> by squared Euclidean distance.
    /// On ties, the lowest index wins. Logs an error and returns -1 when the tree is empty.
    /// </summary>
    public int GetClosestNodeIdx(Vector3 target)
    {
        int count = NodeCount;
        if (count == 0)
        {
            Debug.LogError("Attempted to find closest node in empty RRT");
            return -1;
        }

        int bestIdx = 0;
        float bestSqrDist = (Nodes[0].Pos - target).sqrMagnitude;
        for (int idx = 1; idx < count; idx++)
        {
            float candidate = (Nodes[idx].Pos - target).sqrMagnitude;
            if (!(candidate < bestSqrDist))
            {
                continue;
            }
            bestSqrDist = candidate;
            bestIdx = idx;
        }
        return bestIdx;
    }
}
// Tree data structure
using System;
using System.Collections.Generic;
using UnityEngine;
/// <summary>
/// Ordered list of child-node indices belonging to a single tree node.
/// </summary>
[Serializable]
public class EdgeList
{
    [SerializeField]
    List<int> edges = new List<int>();

    public EdgeList() { }

    /// <summary>Child index stored at position <paramref name="i"/>.</summary>
    public int this[int i]
    {
        get { return edges[i]; }
        set { edges[i] = value; }
    }

    /// <summary>The underlying, mutable list of child indices.</summary>
    public List<int> Edges => edges;

    /// <summary>Number of stored child indices.</summary>
    public int Count => edges.Count;
}
/// <summary>
/// Base class for tree nodes. Each node owns the indices of its children.
/// </summary>
[Serializable]
public abstract class Node
{
    [SerializeField]
    EdgeList children = new EdgeList();

    public Node() { }

    /// <summary>Indices of this node's children within the owning tree's node list.</summary>
    public EdgeList Children => children;
}
/// <summary>
/// Index-based tree: nodes are stored in a flat list, and every node keeps the list indices
/// of its children. The node at index 0 is treated as the root.
/// </summary>
/// <typeparam name="T">Node type; its <c>Children</c> hold indices into <see cref="Nodes"/>.</typeparam>
[Serializable]
public class TreeGraph<T> where T: Node
{
    [SerializeField]
    List<T> nodes;

    /// <summary>Number of nodes in the tree.</summary>
    public int NodeCount { get => nodes.Count; }

    /// <summary>The underlying node list. Child edges refer to positions in this list.</summary>
    public List<T> Nodes { get => nodes; }

    /// <summary>The node at index 0, or <c>null</c> when the tree is empty.</summary>
    public T Root {
        get {
            if (nodes.Count == 0) {
                return null;
            } else {
                return nodes[0];
            }
        }
    }

    /// <summary>Number of direct children of the node at <paramref name="nodeIdx"/>.</summary>
    public int ChildCount(int nodeIdx) {
        return nodes[nodeIdx].Children.Count;
    }

    public TreeGraph()
    {
        nodes = new List<T>();
    }

    /// <summary>Appends <paramref name="node"/> without linking it to any parent.</summary>
    public void AddNode(T node)
    {
        nodes.Add(node);
    }

    /// <summary>
    /// Removes only the node at <paramref name="index"/>; its descendants stay in the list.
    /// Edges pointing at the removed node are deleted, and edges to later nodes are renumbered.
    /// </summary>
    /// <exception cref="ArgumentOutOfRangeException">Thrown by the list when <paramref name="index"/> is invalid.</exception>
    public void RemoveNode(int index)
    {
        nodes.RemoveAt(index);
        // Every node after `index` shifted down by one. Nodes and edges are publicly mutable,
        // so scan all remaining nodes rather than assuming parents precede children.
        for (int i = 0; i < NodeCount; i++)
        {
            // Walk backwards so RemoveAt does not skip the following edge.
            for (int j = ChildCount(i)-1; j >= 0; j--) {
                if (nodes[i].Children.Edges[j] > index) {
                    nodes[i].Children.Edges[j] -= 1;
                } else if (nodes[i].Children.Edges[j] == index) {
                    nodes[i].Children.Edges.RemoveAt(j);
                }
            }
        }
    }

    /// <summary>
    /// Removes every node in the <paramref name="subtree"/> work list together with all of its
    /// descendants. The list is consumed. This is iterative, so large subtrees cannot
    /// overflow the call stack.
    /// </summary>
    void RemoveSubtree(List<int> subtree)
    {
        while (subtree.Count > 0)
        {
            int index = subtree[0];
            subtree.RemoveAt(0);
            // Queue the children before the node (and its edge list) leaves the tree.
            subtree.AddRange(nodes[index].Children.Edges);
            RemoveNode(index);
            // RemoveNode shifted every later node down by one; keep pending indices in sync.
            for (int i = 0; i < subtree.Count; i++) {
                if (subtree[i] > index) {
                    subtree[i] -= 1;
                }
            }
        }
    }

    /// <summary>Removes the node at <paramref name="index"/> and all of its descendants.</summary>
    public void RemoveSubtree(int index)
    {
        RemoveSubtree(new List<int> { index });
    }

    /// <summary>
    /// Appends <paramref name="child"/> to the node list and links it as the last child of
    /// <paramref name="parent"/>. Logs an error and adds nothing when the parent index is out of range.
    /// </summary>
    public void AddChildNode(int parent, T child)
    {
        if (parent < 0 || parent >= nodes.Count) {
            Debug.LogError("Invalid parent node given");
        } else {
            nodes[parent].Children.Edges.Add(NodeCount);
            nodes.Add(child);
        }
    }

    /// <summary>
    /// Lazily enumerates the indices of <paramref name="index"/> and its descendants in
    /// depth-first pre-order. A stack is used, so siblings come out in reverse of their stored order.
    /// </summary>
    public IEnumerator<int> TraverseSubtree(int index)
    {
        Stack<int> subtree = new Stack<int>();
        subtree.Push(index);
        while (subtree.Count > 0)
        {
            int current = subtree.Pop();
            yield return current;
            for (int i = 0; i < ChildCount(current); i++) {
                subtree.Push(nodes[current].Children[i]);
            }
        }
    }

    /// <summary>
    /// Finds which node lists <paramref name="child"/> as a child.
    /// </summary>
    /// <returns>
    /// (parent index, position of <paramref name="child"/> within the parent's edge list).
    /// Returns (-1, -1) when no parent exists (e.g. for the root). Also returns (-1, -1),
    /// after logging an error, when <paramref name="child"/> is not a valid node index.
    /// </returns>
    public (int, int) FindParentAndChildIndex(int child)
    {
        if (child < 0 || child >= NodeCount) {
            Debug.LogError("Child index not in tree");
            return (-1, -1);
        }
        for (int i = 0; i < NodeCount; i++) {
            int childIndex = Nodes[i].Children.Edges.IndexOf(child);
            if (childIndex >= 0) {
                return (i, childIndex);
            }
        }
        return (-1, -1);
    }

    /// <summary>Removes all nodes.</summary>
    public void Clear()
    {
        this.nodes.Clear();
    }
}
using System.Collections.Generic;
using UnityEngine;
/// <summary>
/// Work item for the renderer's breadth-first traversals: a node index paired with the
/// accumulated edge length from the root to that node.
/// </summary>
struct TreeRenderNode
{
    public int nodeIdx;
    public float pathLength;

    public TreeRenderNode(int nodeIdx, float pathLength) => (this.nodeIdx, this.pathLength) = (nodeIdx, pathLength);
}
/// <summary>
/// Renders a <see cref="TreeGraph{T}"/> of <see cref="VectorNode"/>s as a tapering ribbon mesh.
/// Each edge is widened along the camera's right axis. Width and the V texture coordinate both
/// fall off linearly with path length from the root.
/// </summary>
[RequireComponent(typeof(MeshFilter), typeof(MeshRenderer))]
public class TreeGraphRenderer : MonoBehaviour
{
    /// <summary>Camera whose right axis orients the ribbon width; falls back to Camera.main.</summary>
    [SerializeField]
    Camera cam;
    /// <summary>Ribbon width at the root; the width reaches zero at the farthest node.</summary>
    [SerializeField]
    float widthMultiplier = 1f;

    Mesh treeMesh;
    MeshRenderer treeRenderer;
    TreeGraph<VectorNode> tree;
    // Geometry buffers borrowed from ListPool by Clear() and handed back by Apply().
    List<Vector3> vertices;
    List<int> triangles;
    List<Vector2> uvs;
    // Longest accumulated edge length from the root to any node of `tree`.
    float maximumPathLength;

    /// <summary>
    /// Tree to render. Assigning rebuilds the mesh immediately; assigning null clears the mesh.
    /// </summary>
    public TreeGraph<VectorNode> Tree {
        get => tree;
        set {
            Clear();
            tree = value;
            if (tree != null) {
                ComputeMaxPathLength(tree);
                DrawTree();
                Apply();
            }
        }
    }

    void Awake()
    {
        if (cam == null) {
            cam = Camera.main;
        }
        treeRenderer = GetComponent<MeshRenderer>();
        GetComponent<MeshFilter>().mesh = treeMesh = new Mesh();
        treeMesh.name = "Tree Render Mesh";
    }

    /// <summary>
    /// Rebuilds the mesh every frame so the ribbon width stays aligned with the camera's current right axis.
    /// </summary>
    void Update()
    {
        if (tree != null) {
            Clear();
            DrawTree();
            Apply();
        }
    }

    /// <summary>
    /// Breadth-first traversal from the root (node 0). Stores the longest accumulated edge
    /// length to any node in <see cref="maximumPathLength"/>. An empty tree yields 0.
    /// </summary>
    void ComputeMaxPathLength(TreeGraph<VectorNode> tree)
    {
        maximumPathLength = 0f;
        if (tree.NodeCount == 0) {
            return;
        }

        float pathLength = 0f;
        Queue<TreeRenderNode> searchQueue = new Queue<TreeRenderNode>();
        searchQueue.Enqueue(new TreeRenderNode(0, 0f));
        while (searchQueue.Count > 0)
        {
            TreeRenderNode renderNode = searchQueue.Dequeue();
            int node = renderNode.nodeIdx;
            int childCount = tree.ChildCount(node);
            VectorNode vecNode = tree.Nodes[node];
            if (renderNode.pathLength > pathLength)
            {
                pathLength = renderNode.pathLength;
            }
            for (int i = 0; i < childCount; i++)
            {
                int childIdx = vecNode.Children[i];
                float distToChild = (vecNode.Pos - tree.Nodes[childIdx].Pos).magnitude;
                searchQueue.Enqueue(new TreeRenderNode(childIdx, renderNode.pathLength + distToChild));
            }
        }
        maximumPathLength = pathLength;
    }

    /// <summary>
    /// Meshes created with <c>new Mesh()</c> are not released with the component, so destroy it explicitly.
    /// </summary>
    void OnDestroy()
    {
        if (treeMesh != null) {
            Destroy(treeMesh);
            treeMesh = null;
        }
    }

    /// <summary>Empties the mesh and borrows fresh geometry buffers from the list pools.</summary>
    public void Clear()
    {
        if (treeMesh != null) {
            treeMesh.Clear();
        }
        vertices = ListPool<Vector3>.Get();
        triangles = ListPool<int>.Get();
        uvs = ListPool<Vector2>.Get();
    }

    /// <summary>Uploads the geometry buffers to the mesh and returns the buffers to their pools.</summary>
    public void Apply()
    {
        if (treeMesh != null) {
            treeMesh.SetVertices(vertices);
            treeMesh.SetTriangles(triangles, 0);
            treeMesh.SetUVs(0, uvs);
            treeMesh.RecalculateNormals();
        }
        // Mesh.Set* copy the data into the mesh, so the buffers can be recycled.
        ListPool<Vector3>.Add(vertices);
        ListPool<int>.Add(triangles);
        ListPool<Vector2>.Add(uvs);
    }

    /// <summary>
    /// Fraction of the longest path still remaining at <paramref name="pathLength"/>:
    /// 1 at the root and 0 at the farthest node. Returns 1 when the longest path has zero
    /// length, avoiding a 0/0 NaN.
    /// </summary>
    float PathCompletion(float pathLength)
    {
        if (maximumPathLength <= 0f) {
            return 1f;
        }
        return (maximumPathLength - pathLength) / maximumPathLength;
    }

    /// <summary>
    /// Appends geometry for every edge of <see cref="tree"/> to the current buffers, walking
    /// breadth-first from the root. An edge into a node with children becomes a quad; an edge
    /// into a leaf becomes a triangle ending at the leaf's position.
    /// </summary>
    void DrawTree()
    {
        if (tree == null)
        {
            Debug.LogWarning("No tree assigned to tree graph renderer.");
            return;
        }
        if (tree.NodeCount == 0)
        {
            return;
        }
        if (cam == null)
        {
            // The main camera may not have existed yet when Awake ran.
            cam = Camera.main;
            if (cam == null)
            {
                Debug.LogWarning("No camera available for tree graph renderer.");
                return;
            }
        }
        // The camera does not move while the mesh is being built, so read its axis once.
        Vector3 camRight = cam.transform.right;

        Queue<TreeRenderNode> nodeQueue = new Queue<TreeRenderNode>();
        nodeQueue.Enqueue(new TreeRenderNode(0, 0f));
        while (nodeQueue.Count > 0)
        {
            TreeRenderNode renderNode = nodeQueue.Dequeue();
            int node = renderNode.nodeIdx;
            int childCount = tree.ChildCount(node);
            VectorNode vecNode = tree.Nodes[node];
            float nodePathCompletion = PathCompletion(renderNode.pathLength);
            float width = widthMultiplier * nodePathCompletion;
            Vector3 v1 = vecNode.Pos - camRight * width / 2;
            Vector3 v2 = vecNode.Pos + camRight * width / 2;
            for (int i = 0; i < childCount; i++)
            {
                int childIdx = vecNode.Children[i];
                VectorNode childNode = tree.Nodes[childIdx];
                float distToChild = (vecNode.Pos - childNode.Pos).magnitude;
                float childPathLength = renderNode.pathLength + distToChild;
                float childPathCompletion = PathCompletion(childPathLength);
                if (tree.ChildCount(childIdx) > 0) {
                    float childWidth = widthMultiplier * childPathCompletion;
                    // Inner node: visit its children later, and draw a quad spanning this edge.
                    nodeQueue.Enqueue(new TreeRenderNode(childIdx, childPathLength));
                    Vector3 v3 = childNode.Pos - camRight * childWidth / 2;
                    Vector3 v4 = childNode.Pos + camRight * childWidth / 2;
                    AddQuad(v1, v2, v3, v4);
                    AddQuadUV(0, 1, nodePathCompletion, childPathCompletion);
                } else {
                    // Leaf: narrow the edge down to a single point at the leaf.
                    AddTriangle(v1, v2, childNode.Pos);
                    AddTriangleUV(
                        new Vector2(0, nodePathCompletion),
                        new Vector2(1, nodePathCompletion),
                        new Vector2(0.5f, childPathCompletion)
                    );
                }
            }
        }
    }

    /// <summary>Appends a single-sided triangle with winding v1, v2, v3.</summary>
    public void AddTriangle(Vector3 v1, Vector3 v2, Vector3 v3)
    {
        int vertexIndex = vertices.Count;
        vertices.Add(v1);
        vertices.Add(v2);
        vertices.Add(v3);
        triangles.Add(vertexIndex + 0);
        triangles.Add(vertexIndex + 1);
        triangles.Add(vertexIndex + 2);
    }

    /// <summary>Appends a triangle in both windings, so it has a face on each side.</summary>
    public void AddDoubleSidedTriangle(Vector3 v1, Vector3 v2, Vector3 v3)
    {
        int vertexIndex = vertices.Count;
        vertices.Add(v1);
        vertices.Add(v2);
        vertices.Add(v3);
        triangles.Add(vertexIndex + 0);
        triangles.Add(vertexIndex + 1);
        triangles.Add(vertexIndex + 2);
        triangles.Add(vertexIndex + 0);
        triangles.Add(vertexIndex + 2);
        triangles.Add(vertexIndex + 1);
    }

    /// <summary>
    /// Appends a single-sided quad. v1/v2 form one edge and v3/v4 the opposite edge;
    /// it is split into triangles (v1, v3, v2) and (v2, v3, v4).
    /// </summary>
    public void AddQuad(Vector3 v1, Vector3 v2, Vector3 v3, Vector3 v4)
    {
        int vertexIndex = vertices.Count;
        vertices.Add(v1);
        vertices.Add(v2);
        vertices.Add(v3);
        vertices.Add(v4);
        triangles.Add(vertexIndex + 0);
        triangles.Add(vertexIndex + 2);
        triangles.Add(vertexIndex + 1);
        triangles.Add(vertexIndex + 1);
        triangles.Add(vertexIndex + 2);
        triangles.Add(vertexIndex + 3);
    }

    /// <summary>Same layout as <see cref="AddQuad"/>, plus the reverse winding for the back face.</summary>
    public void AddDoubleSidedQuad(Vector3 v1, Vector3 v2, Vector3 v3, Vector3 v4)
    {
        int vertexIndex = vertices.Count;
        vertices.Add(v1);
        vertices.Add(v2);
        vertices.Add(v3);
        vertices.Add(v4);
        triangles.Add(vertexIndex + 0);
        triangles.Add(vertexIndex + 2);
        triangles.Add(vertexIndex + 1);
        triangles.Add(vertexIndex + 1);
        triangles.Add(vertexIndex + 2);
        triangles.Add(vertexIndex + 3);
        triangles.Add(vertexIndex + 0);
        triangles.Add(vertexIndex + 1);
        triangles.Add(vertexIndex + 2);
        triangles.Add(vertexIndex + 2);
        triangles.Add(vertexIndex + 1);
        triangles.Add(vertexIndex + 3);
    }

    /// <summary>Appends UVs for the three vertices of the most recent triangle.</summary>
    public void AddTriangleUV (Vector2 uv1, Vector2 uv2, Vector2 uv3) {
        uvs.Add(uv1);
        uvs.Add(uv2);
        uvs.Add(uv3);
    }

    /// <summary>Appends UVs matching the v1..v4 vertex order used by <see cref="AddQuad"/>.</summary>
    public void AddQuadUV (float uMin, float uMax, float vMin, float vMax) {
        uvs.Add(new Vector2(uMin, vMin));
        uvs.Add(new Vector2(uMax, vMin));
        uvs.Add(new Vector2(uMin, vMax));
        uvs.Add(new Vector2(uMax, vMax));
    }
}
using UnityEngine;
/// <summary>
/// Tree node carrying a 3D position.
/// </summary>
// [Serializable] is not inherited from Node, so it must be repeated here. Without it,
// TreeGraph<VectorNode>'s serialized node list cannot store these nodes.
[System.Serializable]
public class VectorNode : Node
{
    // Serialized so the position survives alongside the (already serialized) child edges.
    [SerializeField]
    Vector3 pos;

    /// <summary>Position of this node, fixed at construction.</summary>
    public Vector3 Pos { get => pos; }

    public VectorNode(Vector3 pos) : base()
    {
        this.pos = pos;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment