Skip to content

Instantly share code, notes, and snippets.

@IAD
Created January 16, 2019 17:28
Show Gist options
  • Save IAD/9f678cde73740941d3f86982aff365ea to your computer and use it in GitHub Desktop.
Save IAD/9f678cde73740941d3f86982aff365ea to your computer and use it in GitHub Desktop.
package fsm
import (
"context"
"reflect"
"runtime"
"strings"
)
// StateFunc is one state of an Fsm. Calling it performs the state's work and
// yields the state that should run next. Returning nil ends the state: Start
// then resumes the most recently pushed state (see Fsm.PushState), or stops
// when no pushed state remains.
type StateFunc func() (next StateFunc)
// NewFSM returns an Fsm that is ready to be run with Start.
//
// beforeTrigger, when non-nil, is called right before each state runs with the
// name of the previously running state ("" before the first one) and the name
// of the state about to run. afterTrigger, when non-nil, is called after each
// state returns with that state's name and the name of the state it returned
// ("" for nil).
func NewFSM(
	beforeTrigger func(previousState, currentState string),
	afterTrigger func(currentStat, nextState string),
) *Fsm {
	m := new(Fsm)
	m.beforeTrigger = beforeTrigger
	m.afterTrigger = afterTrigger
	// done is unbuffered and only ever closed; state holds exactly one pending
	// state between iterations of Start.
	m.done = make(chan struct{})
	m.state = make(chan StateFunc, 1)
	m.stack = []StateFunc{}
	return m
}
// Fsm is a finite state machine whose states are StateFunc values. A state may
// hand control to a sub-state with PushState and is run again once that
// sub-state returns nil.
//
// Create it with NewFSM: the zero value is not usable, because Start sends on
// the state channel and closes the done channel that NewFSM allocates.
type Fsm struct {
	// done is closed when Start returns; exposed through Done.
	done chan struct{}
	// state holds the next state to run between iterations of Start.
	state chan StateFunc
	// current is the state Start most recently began running; PushState saves it.
	current StateFunc
	// stack holds states waiting to be resumed, most recently pushed last.
	stack []StateFunc
	// beforeTrigger, if non-nil, is called with the previous and the upcoming
	// state names right before each state runs.
	beforeTrigger func(previousState, currentState string)
	// afterTrigger, if non-nil, is called with the finished state's name and the
	// name of the state it returned, right after each state runs.
	afterTrigger func(currentState, nextState string)
}
// Start runs the machine on the calling goroutine, beginning with state, and
// blocks until it stops. It currently always returns nil.
//
// On every iteration the pending state is taken and run:
//   - a nil pending state means the last state finished: the most recently
//     pushed state (see PushState) is resumed, and if the stack is empty the
//     machine stops;
//   - beforeTrigger (if set) is called just before the state runs, and
//     afterTrigger (if set) just after it returns, with names produced by
//     getStateName.
//
// ctx is only checked between states; a running StateFunc is never interrupted.
// Start closes the Done channel when it returns, so it must be called at most
// once per Fsm (closing an already-closed channel panics).
func (fsm *Fsm) Start(ctx context.Context, state StateFunc) error {
	defer close(fsm.done)

	// Seed the one-slot state channel with the initial state.
	fsm.state <- state

	for ctx.Err() == nil {
		var pending StateFunc
		select {
		case <-ctx.Done():
			return nil
		case pending = <-fsm.state:
		}

		if pending == nil {
			// The previous state is done: go back to whoever pushed it.
			if pending = fsm.popState(); pending == nil {
				return nil // nothing left to resume
			}
		}

		if before := fsm.beforeTrigger; before != nil {
			before(getStateName(fsm.current), getStateName(pending))
		}

		fsm.current = pending
		next := pending()
		fsm.state <- next

		if after := fsm.afterTrigger; after != nil {
			after(getStateName(pending), getStateName(next))
		}
	}
	return nil
}
// Done returns a channel that is closed when Start returns, either because
// the state stack was exhausted or because its context was cancelled.
func (fsm *Fsm) Done() <-chan struct{} {
	return fsm.done
}
// PushState saves the currently running state on the stack and returns state
// unchanged. A StateFunc hands control to a sub-state with
//
//	return fsm.PushState(subState)
//
// Once subState (or any state it leads to) returns nil, Start pops the saved
// state and runs it again.
//
// It is meant to be called from inside a running StateFunc, where the current
// state is that caller. Before Start has run any state the current state is
// nil, so a nil entry is pushed; Start stops when it pops such an entry.
func (fsm *Fsm) PushState(state StateFunc) StateFunc {
	fsm.stack = append(fsm.stack, fsm.current)
	return state
}
// popState removes and returns the most recently pushed state, or nil when the
// stack is empty. Start calls it when a state returns nil, to resume the state
// that pushed it.
func (fsm *Fsm) popState() StateFunc {
	n := len(fsm.stack)
	if n == 0 {
		return nil
	}
	top := fsm.stack[n-1]
	// Clear the vacated slot so the backing array stops referencing the popped
	// closure (and whatever it captured), letting the GC reclaim it.
	fsm.stack[n-1] = nil
	fsm.stack = fsm.stack[:n-1]
	return top
}
// getStateName returns a short name for f, passed to the before/after
// triggers. It returns "" for a nil StateFunc.
//
// Runtime function names carry a package path and may carry a receiver and a
// wrapper suffix, e.g. "path/to/pkg.(*Fsm).popState-fm" for a method value.
// Only the last dot-separated element is kept and any "-suffix" on it is
// dropped, so that example becomes "popState" and a plain "pkg.idle" becomes
// "idle".
func getStateName(f StateFunc) string {
	if f == nil {
		return ""
	}
	name := runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name()
	// Strip the package path and receiver first: the path itself may contain
	// '-' (e.g. "github.com/foo-bar/pkg"), so the suffix marker must only be
	// searched for in the final element.
	if i := strings.LastIndex(name, "."); i >= 0 {
		name = name[i+1:]
	}
	// Method values end in "-fm"; plain functions and closures have no such
	// suffix, so it is optional here.
	if i := strings.Index(name, "-"); i >= 0 {
		name = name[:i]
	}
	// NOTE(review): stray ')' is stripped as before — presumably for toolchains
	// that render unexported method names as "(pkg.method)"; confirm.
	return strings.Replace(name, ")", "", -1)
}
package fsm
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// TestFsmStackSimple runs an idle state that pushes an increment sub-state
// until the counter reaches 10 and then returns nil, which stops the machine.
func TestFsmStackSimple(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	fsm := NewFSM(nil, nil)
	count := 0
	increase := func() StateFunc {
		count++
		return nil
	}
	idle := func() StateFunc {
		if count < 10 {
			return fsm.PushState(increase)
		}
		return nil
	}
	go fsm.Start(ctx, idle)
	// Wait for the machine to stop instead of sleeping: closing Done also
	// orders the FSM goroutine's writes to count before the read below.
	select {
	case <-fsm.Done():
	case <-time.After(time.Second):
		t.Fatal("fsm did not stop in time")
	}
	assert.Equal(t, 10, count)
}
// TestFsmStack exercises nested pushes: idle pushes increaseA until a reaches
// 10, then pushes increaseB, which itself pushes decreaseA to bring a back to 0
// before bumping b. The machine stops once b reaches 10.
func TestFsmStack(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	fsm := NewFSM(nil, nil)
	a := 0
	b := 0
	increaseA := func() StateFunc {
		a++
		return nil
	}
	decreaseA := func() StateFunc {
		a -= 10
		return nil
	}
	increaseB := func() StateFunc {
		if a >= 10 {
			return fsm.PushState(decreaseA)
		}
		b++
		return nil
	}
	idle := func() StateFunc {
		if b == 10 {
			return nil
		}
		if a < 10 {
			return fsm.PushState(increaseA)
		}
		return fsm.PushState(increaseB)
	}
	go fsm.Start(ctx, idle)
	// Wait for the machine to stop instead of sleeping: closing Done also
	// orders the FSM goroutine's writes to a and b before the reads below.
	select {
	case <-fsm.Done():
	case <-time.After(time.Second):
		t.Fatal("fsm did not stop in time")
	}
	assert.Equal(t, 10, b)
	assert.Equal(t, 0, a)
}
// TestGetStateName checks that a method value is reported by its bare method
// name, without the package path or receiver decoration.
func TestGetStateName(t *testing.T) {
	machine := &Fsm{}
	assert.Equal(t, "popState", getStateName(machine.popState))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment