Skip to content

Instantly share code, notes, and snippets.

@p3nGu1nZz
Created March 1, 2024 18:50
Show Gist options
  • Save p3nGu1nZz/93b498fe2a4f2bcf2ebf2506345587fc to your computer and use it in GitHub Desktop.
Save p3nGu1nZz/93b498fe2a4f2bcf2ebf2506345587fc to your computer and use it in GitHub Desktop.
using System;
using System.Collections.Generic;
using UnityEngine;
namespace HFSM
{
    /// <summary>
    /// A node in a hierarchical finite state machine. Each node is a state, in that it can be the
    /// active substate of its parent. It is also a machine, in that it can own substates and forward
    /// lifecycle and collision calls to whichever substate is currently active.
    /// </summary>
    /// <typeparam name="T">Context object passed to every <c>On*</c> hook in the hierarchy.</typeparam>
    public abstract class StateMachine<T>
    {
        // Context handed to every On* hook; assigned by Bind.
        T ctx;
        // Active substate; null until the first Enter (or when no substates are loaded).
        StateMachine<T> current;
        // First substate passed to Load; activated on the first Enter.
        StateMachine<T> init;
        // Owning machine; null for the root.
        StateMachine<T> parent;

        /// <summary>The machine this state was loaded into, or <c>null</c> for the root.</summary>
        public StateMachine<T> Parent => parent;

        /// <summary>The currently active substate, or <c>null</c> if none is active.</summary>
        public StateMachine<T> Current => current;

        // Substates keyed by runtime type, so each concrete type appears at most once per machine.
        readonly Dictionary<Type, StateMachine<T>> states = new Dictionary<Type, StateMachine<T>>();
        // Outgoing transitions of this state, keyed by trigger id.
        readonly Dictionary<int, StateMachine<T>> transitions = new Dictionary<int, StateMachine<T>>();

        /// <summary>
        /// Substates keyed by their runtime type. Mutating this dictionary directly bypasses the
        /// checks performed by <see cref="Load"/>.
        /// </summary>
        public Dictionary<Type, StateMachine<T>> States => states;

        /// <summary>
        /// Outgoing transitions of this state keyed by trigger id. Mutating this dictionary directly
        /// bypasses the checks performed by <see cref="AddTransition"/>.
        /// </summary>
        public Dictionary<int, StateMachine<T>> Transitions => transitions;

        // NOTE(review): empty placeholder; presumably meant for grouping trigger ids. TODO confirm intent.
        public readonly struct Triggers { }

        /// <summary>
        /// Stores <paramref name="ctx"/> and runs <see cref="OnLoad"/>. Substates loaded before this
        /// call are not re-bound here; substates loaded afterwards are bound by <see cref="Load"/>.
        /// </summary>
        public void Bind(T ctx)
        {
            this.ctx = ctx;
            OnLoad(ctx);
        }

        /// <summary>Runs <see cref="OnStart"/> on this node only; substates are not started.</summary>
        public void Start()
        {
            OnStart(ctx);
        }

        /// <summary>
        /// Runs <see cref="OnEnter"/>, then enters the active substate. On the first entry, the
        /// initial substate becomes active. <see cref="Exit"/> does not clear the active substate,
        /// so a re-entry resumes the substate that was active last time.
        /// </summary>
        public void Enter()
        {
            OnEnter(ctx);
            if (current == null && init != null)
            {
                current = init;
            }
            current?.Enter();
        }

        /// <summary>Runs <see cref="OnFixedUpdate"/> here, then on the active substate chain.</summary>
        public void FixedUpdate()
        {
            OnFixedUpdate(ctx);
            current?.FixedUpdate();
        }

        /// <summary>Runs <see cref="OnUpdate"/> here, then on the active substate chain.</summary>
        public void Update()
        {
            OnUpdate(ctx);
            current?.Update();
        }

        /// <summary>Runs <see cref="OnLateUpdate"/> here, then on the active substate chain.</summary>
        public void LateUpdate()
        {
            OnLateUpdate(ctx);
            current?.LateUpdate();
        }

        /// <summary>Runs <see cref="OnAnimateUpdate"/> here, then on the active substate chain.</summary>
        public void AnimateUpdate()
        {
            OnAnimateUpdate(ctx);
            current?.AnimateUpdate();
        }

        /// <summary>Runs <see cref="OnCollisionEnter"/> here, then on the active substate chain.</summary>
        public void CollisionEnter(Collision collision)
        {
            OnCollisionEnter(ctx, collision);
            current?.CollisionEnter(collision);
        }

        /// <summary>Runs <see cref="OnCollisionExit"/> here, then on the active substate chain.</summary>
        public void CollisionExit(Collision collision)
        {
            OnCollisionExit(ctx, collision);
            current?.CollisionExit(collision);
        }

        /// <summary>Exits the active substate first (innermost first), then runs <see cref="OnExit"/>.</summary>
        public void Exit()
        {
            current?.Exit();
            OnExit(ctx);
        }

        protected virtual void OnLoad(T context) { }
        protected virtual void OnStart(T context) { }
        protected virtual void OnEnter(T context) { }
        protected virtual void OnFixedUpdate(T context) { }
        protected virtual void OnUpdate(T context) { }
        protected virtual void OnLateUpdate(T context) { }
        protected virtual void OnAnimateUpdate(T context) { }
        protected virtual void OnCollisionEnter(T context, Collision collision) { }
        protected virtual void OnCollisionExit(T context, Collision collision) { }
        protected virtual void OnExit(T context) { }

        /// <summary>
        /// Registers <paramref name="state"/> as a substate of this machine. The first substate
        /// loaded becomes the initial state. If this machine already has a context, the substate is
        /// bound to it, which runs the substate's <see cref="OnLoad"/>.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="state"/> is null.</exception>
        /// <exception cref="ArgumentException">A substate of the same runtime type is already loaded.</exception>
        public void Load(StateMachine<T> state)
        {
            if (state == null)
            {
                throw new ArgumentNullException(nameof(state));
            }
            // Reject duplicates before touching anything, so that a failed Load neither reparents the
            // state, nor fires its OnLoad, nor changes the initial state.
            if (states.ContainsKey(state.GetType()))
            {
                throw new ArgumentException($"State {GetType()} already has substate of type {state.GetType()}");
            }
            if (states.Count == 0)
            {
                init = state;
            }
            state.parent = this;
            // For a non-nullable value-type T this check is always true, so the substate is bound
            // even before Bind has been called on this machine.
            if (ctx != null)
            {
                state.Bind(ctx);
            }
            states.Add(state.GetType(), state);
        }

        /// <summary>
        /// Adds a transition between two substates of this machine. When <paramref name="trigger"/>
        /// fires while <paramref name="from"/> is active, <paramref name="to"/> becomes active.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="from"/> or <paramref name="to"/> is null.</exception>
        /// <exception cref="ArgumentException">
        /// Either state is not a substate of this machine, or <paramref name="from"/> already has a
        /// transition for <paramref name="trigger"/>.
        /// </exception>
        public void AddTransition(StateMachine<T> from, StateMachine<T> to, int trigger)
        {
            if (from == null)
            {
                throw new ArgumentNullException(nameof(from));
            }
            if (to == null)
            {
                throw new ArgumentNullException(nameof(to));
            }
            RequireSubstate(this, from);
            RequireSubstate(this, to);
            AddOutgoing(from, to, trigger);
        }

        /// <summary>
        /// Adds a transition from a substate of this machine to a state that is registered in some
        /// other machine, typically a substate of a sibling.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="from"/> or <paramref name="to"/> is null.</exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="from"/> is not a substate of this machine, <paramref name="to"/> was never
        /// loaded into a machine, or <paramref name="from"/> already has a transition for
        /// <paramref name="trigger"/>.
        /// </exception>
        public void AddChildTransition(StateMachine<T> from, StateMachine<T> to, int trigger)
        {
            if (from == null)
            {
                throw new ArgumentNullException(nameof(from));
            }
            if (to == null)
            {
                throw new ArgumentNullException(nameof(to));
            }
            RequireSubstate(this, from);
            if (to.parent == null)
            {
                throw new ArgumentException($"State {to.GetType()} has not been loaded into any state machine");
            }
            RequireSubstate(to.parent, to);
            AddOutgoing(from, to, trigger);
        }

        /// <summary>
        /// Fires <paramref name="trigger"/>. The method walks the active chain from the root down.
        /// The first active state with a transition for the trigger hands its target to its own
        /// parent, which performs the change. If no active state matches, nothing happens.
        /// </summary>
        public void Trigger(int trigger)
        {
            StateMachine<T> node = this;
            while (node.parent != null)
            {
                node = node.parent;
            }
            for (; node != null; node = node.current)
            {
                if (node.transitions.TryGetValue(trigger, out StateMachine<T> target))
                {
                    node.parent?.Change(target);
                    return;
                }
            }
        }

        // Throws unless `owner` has a substate registered under `state`'s runtime type.
        static void RequireSubstate(StateMachine<T> owner, StateMachine<T> state)
        {
            if (!owner.states.ContainsKey(state.GetType()))
            {
                throw new ArgumentException($"State {owner.GetType()} missing substate for {state.GetType()}");
            }
        }

        // Records `to` as the target of `trigger` on `from`, rejecting a duplicate trigger id.
        static void AddOutgoing(StateMachine<T> from, StateMachine<T> to, int trigger)
        {
            if (from.transitions.ContainsKey(trigger))
            {
                throw new ArgumentException($"State {from.GetType()} duplicate transition trigger {trigger}");
            }
            from.transitions.Add(trigger, to);
        }

        // Exits the active substate and activates `state`. `state` is either one of this machine's own
        // substates, or a substate of `state.parent` (a child transition). The target is resolved
        // before anything is exited, so an invalid change throws without leaving the machine in an
        // exited-but-still-current state.
        private void Change(StateMachine<T> state)
        {
            if (states.TryGetValue(state.GetType(), out StateMachine<T> sibling))
            {
                current?.Exit();
                current = sibling;
                current.Enter();
                return;
            }

            StateMachine<T> owner = state.parent;
            if (owner != null && owner.states.TryGetValue(state.GetType(), out StateMachine<T> child))
            {
                current?.Exit();
                // Make the owner active here and point it at the target. Entering the owner then
                // runs its OnEnter and enters the target. Assigning `current` directly avoids
                // dereferencing a possibly-null `current`.
                current = owner;
                owner.current = child;
                owner.Enter();
                return;
            }

            throw new InvalidOperationException($"Missing state {state.GetType()} in {(owner ?? this).GetType()}");
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment