Created
January 16, 2019 17:28
-
-
Save IAD/9f678cde73740941d3f86982aff365ea to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package fsm | |
import ( | |
"context" | |
"reflect" | |
"runtime" | |
"strings" | |
) | |
// StateFunc is a single state of an Fsm. Start calls it each time the machine
// enters the state, and it returns the state to run next. Returning nil means
// the state is finished: Start resumes the state most recently saved with
// Fsm.PushState, or stops the machine when no saved state remains.
type StateFunc func() StateFunc
// NewFSM constructor func | |
func NewFSM( | |
beforeTrigger func(previousState, currentState string), | |
afterTrigger func(currentStat, nextState string), | |
) *Fsm { | |
return &Fsm{ | |
done: make(chan struct{}, 0), | |
state: make(chan StateFunc, 1), | |
stack: make([]StateFunc, 0), | |
beforeTrigger: beforeTrigger, | |
afterTrigger: afterTrigger, | |
} | |
} | |
// Fsm is a Final State Machine | |
type Fsm struct { | |
done chan struct{} | |
state chan StateFunc | |
current StateFunc | |
stack []StateFunc | |
beforeTrigger func(previousState, currentState string) | |
afterTrigger func(currentStat, nextState string) | |
} | |
// Start called to start fsm loop with init state | |
func (fsm *Fsm) Start(ctx context.Context, state StateFunc) error { | |
defer close(fsm.done) | |
fsm.state <- state | |
for { | |
if ctx.Err() != nil { | |
return nil | |
} | |
select { | |
case newState := <-fsm.state: | |
// nil means that state ends and fsm should be returned to a previous state | |
if newState == nil { | |
newState = fsm.popState() | |
// exit if there are no previous state in the stack | |
if newState == nil { | |
return nil | |
} | |
} | |
// call a before action trigger | |
if fsm.beforeTrigger != nil { | |
fsm.beforeTrigger(getStateName(fsm.current), getStateName(newState)) | |
} | |
fsm.current = newState | |
// action | |
nextState := newState() | |
fsm.state <- nextState | |
// call an after action trigger | |
if fsm.afterTrigger != nil { | |
fsm.afterTrigger(getStateName(newState), getStateName(nextState)) | |
} | |
case <-ctx.Done(): | |
return nil | |
} | |
} | |
} | |
func (fsm *Fsm) Done() <-chan struct{} { | |
return fsm.done | |
} | |
// PushState used to push a new state into stack | |
func (fsm *Fsm) PushState(state StateFunc) StateFunc { | |
fsm.stack = append(fsm.stack, fsm.current) | |
return state | |
} | |
// popState used to pop previous state from the stack | |
func (fsm *Fsm) popState() StateFunc { | |
if len(fsm.stack) == 0 { | |
return nil | |
} | |
state := fsm.stack[len(fsm.stack)-1] | |
fsm.stack = fsm.stack[:len(fsm.stack)-1] | |
return state | |
} | |
func getStateName(f StateFunc) string { | |
if f == nil { | |
return "" | |
} | |
fullName := runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name() | |
if fullName == "" { | |
return "" | |
} | |
lastIndex1 := strings.LastIndex(fullName, ".") | |
lastIndex2 := strings.LastIndex(fullName, "-") | |
return strings.Replace(fullName[lastIndex1+1:lastIndex2], ")", "", -1) | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package fsm | |
import ( | |
"context" | |
"testing" | |
"time" | |
"github.com/stretchr/testify/assert" | |
) | |
func TestFsmStackSimple(t *testing.T) { | |
ctx, _ := context.WithCancel(context.Background()) | |
fsm := NewFSM(nil, nil) | |
count := 0 | |
increase := func() StateFunc { | |
count++ | |
return nil | |
} | |
idle := func() StateFunc { | |
if count < 10 { | |
return fsm.PushState(increase) | |
} | |
return nil | |
} | |
go fsm.Start(ctx, idle) | |
time.Sleep(time.Millisecond * 1) | |
assert.Equal(t, 10, count) | |
} | |
func TestFsmStack(t *testing.T) { | |
ctx, _ := context.WithCancel(context.Background()) | |
fsm := NewFSM(nil, nil) | |
a := 0 | |
b := 0 | |
increaseA := func() StateFunc { | |
a++ | |
return nil | |
} | |
decreaseA := func() StateFunc { | |
a -= 10 | |
return nil | |
} | |
increaseB := func() StateFunc { | |
if a >= 10 { | |
return fsm.PushState(decreaseA) | |
} | |
b++ | |
return nil | |
} | |
idle := func() StateFunc { | |
if b == 10 { | |
return nil | |
} | |
if a < 10 { | |
return fsm.PushState(increaseA) | |
} | |
return fsm.PushState(increaseB) | |
} | |
go fsm.Start(ctx, idle) | |
time.Sleep(time.Millisecond * 10) | |
assert.Equal(t, 10, b) | |
assert.Equal(t, 0, a) | |
} | |
func TestGetStateName(t *testing.T) { | |
o := &Fsm{} | |
f := o.popState | |
expected := "popState" | |
assert.Equal(t, expected, getStateName(f)) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment