Skip to content

Instantly share code, notes, and snippets.

@nathan815
Created March 31, 2022 05:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nathan815/426c0d5dabc467f6a0a12dc900218df0 to your computer and use it in GitHub Desktop.
Save nathan815/426c0d5dabc467f6a0a12dc900218df0 to your computer and use it in GitHub Desktop.
Kotlin state machine
package me.nathanjohnson
import android.util.Log
import kotlin.reflect.KClass
/**
 * One edge of a [StateMachine].
 *
 * Suppose an event is triggered while the machine is in [from], and the event's runtime class is
 * exactly [event]. Then the machine moves to [to], unless [guard] vetoes the move.
 *
 * @property event Class of the event that fires this transition. The machine compares it with `==`
 *   against the triggered event's runtime class, so instances of subclasses do not match.
 * @property from State the machine must be in for this transition to apply.
 * @property to State the machine moves to when the transition is taken.
 * @property guard Optional extra check, run before the transition is taken. Returning `false`
 *   stops the transition and leaves the state unchanged.
 * @property action Optional side effect, run when the transition is taken. It may return a
 *   modified resource; returning `null` keeps the resource it was given.
 */
data class StateTransition<StateT, ResourceT>(
    val event: KClass<*>,
    val from: StateT,
    val to: StateT,
    val guard: ((event: StateMachine.Event, resource: ResourceT) -> Boolean)? = null,
    val action: ((event: StateMachine.Event, resource: ResourceT) -> ResourceT?)? = null,
)
/**
 * A finite state machine (FSM) that tracks a single current state.
 *
 * State changes are driven by [trigger]. Each call looks for a registered [StateTransition] whose
 * event class and source state match. If one is found and its guard allows it, the machine moves
 * to that transition's target state.
 * NOTE(review): [currentState] also has a public setter, so callers can bypass [trigger].
 *
 * Transitions and the transition listener are registered in `setup`. It runs once, as the last
 * step of construction, with this machine as its receiver.
 *
 * @param StateT Type of the states the machine moves between.
 * @param ResourceT Type of the resource whose state the machine tracks. The resource is passed to
 *   guards, actions and the [onTransition] listener.
 * @param initialState State the machine starts in.
 * @param setup Configuration block, typically calling [addTransitions] and [onTransition].
 */
class StateMachine<StateT, ResourceT>(
    initialState: StateT,
    setup: StateMachine<StateT, ResourceT>.() -> Unit,
) {
    /** Base class for events. Subclass it (one class per event) to pass data to guards and actions. */
    open class Event

    /** Outcome of a single [trigger] call. */
    enum class TriggerResult {
        /** No transition is registered for the event's class in the current state. */
        NoTransition,

        /** A matching transition exists but its guard returned `false`; the state is unchanged. */
        StoppedByGuard,

        /** The matching transition was taken and the machine is now in its target state. */
        StateUpdated,
    }

    /**
     * Value returned by [trigger].
     *
     * @property result What happened.
     * @property resource Resource produced by the transition's action and the listener. It is
     *   always `null` for [TriggerResult.NoTransition] and [TriggerResult.StoppedByGuard].
     */
    data class TriggerOutput<ResourceT>(val result: TriggerResult, val resource: ResourceT? = null)

    // Registered transitions, in registration order; trigger() uses the first one that matches.
    private var registeredTransitions: List<StateTransition<StateT, ResourceT>> = emptyList()

    // Listener run after every successful transition; by default it hands the resource back unchanged.
    private var transitionListener: (
        transition: StateTransition<StateT, ResourceT>,
        resource: ResourceT,
    ) -> ResourceT = { _, unchanged -> unchanged }

    /** The state the machine is currently in; starts as the constructor's `initialState`. */
    var currentState: StateT = initialState

    init {
        // Declared after every property initializer, so setup can safely register transitions.
        setup(this)
    }

    /**
     * Replaces the listener that runs after every successful transition.
     *
     * The listener receives two arguments:
     * - the transition that was taken;
     * - the resource returned by that transition's action, or the original resource when there is
     *   no action or the action returned `null`.
     *
     * The listener's return value becomes [TriggerOutput.resource]. Only the most recently set
     * listener is used. Calling this with no argument restores the default listener, which returns
     * the resource unchanged.
     */
    fun onTransition(
        func: (transition: StateTransition<StateT, ResourceT>, resource: ResourceT) -> ResourceT = { _, res -> res },
    ) {
        transitionListener = func
    }

    /** Registers [ts]. When several transitions match, the one added earliest wins. */
    fun addTransitions(vararg ts: StateTransition<StateT, ResourceT>) {
        registeredTransitions = registeredTransitions + ts.asList()
    }

    /**
     * Triggers [event] against the current state.
     *
     * The selected transition is the first registered one that meets both conditions:
     * - its `event` class equals the runtime class of [event] (superclasses do not match);
     * - its `from` state equals [currentState].
     *
     * If the selected transition's guard returns `false`, nothing changes. Otherwise [currentState]
     * is set to the transition's `to` state, and then the transition's action and the
     * [onTransition] listener run.
     *
     * @param event The event to trigger.
     * @param resource Resource passed to the guard and the action. It is also passed to the
     *   listener unless the action returns a replacement.
     * @return The outcome. [TriggerOutput.resource] is only populated for [TriggerResult.StateUpdated].
     */
    fun trigger(event: Event, resource: ResourceT): TriggerOutput<ResourceT> {
        Log.i(TAG, "Trigger event $event - Current state is $currentState")

        val match = firstMatchFor(event) ?: run {
            Log.w(TAG, "No transition exists for $event in current state $currentState")
            return TriggerOutput(TriggerResult.NoTransition)
        }

        // A transition without a guard is always allowed.
        val guardPasses = match.guard?.let { guard -> guard(event, resource) } ?: true
        if (!guardPasses) {
            Log.i(
                TAG,
                "Transition stopped by guard. Transition: $match, Event: $event, Resource: $resource",
            )
            return TriggerOutput(TriggerResult.StoppedByGuard)
        }

        // The new state is committed before the action and listener run. If either of them throws,
        // the machine is left in the target state.
        currentState = match.to
        val acted = match.action?.let { action -> action(event, resource) } ?: resource
        return TriggerOutput(
            result = TriggerResult.StateUpdated,
            resource = transitionListener(match, acted),
        )
    }

    // First registered transition for the event's exact runtime class out of the current state, if any.
    private fun firstMatchFor(event: Event): StateTransition<StateT, ResourceT>? =
        registeredTransitions.firstOrNull { it.event == event::class && it.from == currentState }

    companion object {
        private val TAG = StateMachine::class.java.simpleName
    }
}
package com.justlight.sunflower.util
import org.junit.Assert.assertEquals
import org.junit.Assert.assertNull
import org.spekframework.spek2.Spek
import org.spekframework.spek2.style.specification.describe
/** States the test machines move between. */
enum class TestState {
    State1,
    State2,
    State3,
    State4,
    ;

    companion object {
        /** The state the tests start their machines and resources in. */
        val INITIAL: TestState = State1
    }
}
/**
 * Resource whose state the test machines track.
 *
 * @property state State recorded on the resource.
 * @property number Payload that transition actions may overwrite; defaults to 0.
 */
data class TestResource(
    val state: TestState,
    val number: Int = 0,
)
/**
 * Events understood by the test state machines.
 *
 * Each event is its own class for two reasons:
 * - [StateMachine] matches transitions on the event's runtime class;
 * - the machine hands the event instance to guards and actions, so an event can carry data
 *   (see [EventWithNumber]).
 *
 * Events without data are plain `object`s. An `enum` cannot be used instead, because every event
 * must extend [StateMachine.Event] and enum classes cannot extend another class.
 */
object TestEvents {
    /** Common supertype of all test events. */
    sealed class BaseEvent : StateMachine.Event()

    /** Parameterless event. */
    object BasicEvent : BaseEvent()

    /** Second parameterless event, distinct from [BasicEvent]. */
    object AnotherBasicEvent : BaseEvent()

    /**
     * Event carrying a number that guards and actions can read.
     *
     * It is a `data class` so that the state machine's log messages show its payload
     * (`EventWithNumber(someNumber=5)`) instead of an identity hash.
     */
    data class EventWithNumber(val someNumber: Int) : BaseEvent()
}
/**
 * Spek specification for [StateMachine.trigger].
 *
 * Every test builds its own machine, so state changes made by one test cannot leak into another.
 */
class StateMachineTest : Spek({
    describe("trigger") {
        describe("no matching transition") {
            it("does nothing") {
                val resource = TestResource(state = TestState.INITIAL)
                val stateMachine = StateMachine<TestState, TestResource>(TestState.INITIAL) {
                    onTransition { transition, res -> res.copy(state = transition.to) }
                    addTransitions(
                        // Only leaves State2, but the machine is still in State1.
                        StateTransition(
                            event = TestEvents.BasicEvent::class,
                            from = TestState.State2,
                            to = TestState.State3,
                        ),
                    )
                }

                val output = stateMachine.trigger(TestEvents.BasicEvent, resource)

                assertEquals(TestState.INITIAL, stateMachine.currentState)
                assertEquals(StateMachine.TriggerResult.NoTransition, output.result)
                assertNull(output.resource)
            }
        }

        describe("transition with no action or guard") {
            it("changes state") {
                val resource = TestResource(state = TestState.INITIAL)
                val stateMachine = StateMachine<TestState, TestResource>(TestState.INITIAL) {
                    onTransition { transition, res -> res.copy(state = transition.to) }
                    addTransitions(
                        StateTransition(
                            event = TestEvents.BasicEvent::class,
                            from = TestState.State1,
                            to = TestState.State2,
                        ),
                    )
                }

                val output = stateMachine.trigger(TestEvents.BasicEvent, resource)

                assertEquals(StateMachine.TriggerResult.StateUpdated, output.result)
                assertEquals(TestState.State2, stateMachine.currentState)
                assertEquals(resource.copy(state = TestState.State2), output.resource)
            }
        }

        describe("transition with action") {
            it("changes state, executes action, and returns the output in TriggerOutput object") {
                val resource = TestResource(state = TestState.INITIAL, number = 0)
                val stateMachine = StateMachine<TestState, TestResource>(TestState.INITIAL) {
                    onTransition { transition, res -> res.copy(state = transition.to) }
                    addTransitions(
                        StateTransition(
                            event = TestEvents.BasicEvent::class,
                            from = TestState.State1,
                            to = TestState.State2,
                        ),
                        StateTransition(
                            event = TestEvents.EventWithNumber::class,
                            from = TestState.State2,
                            to = TestState.State3,
                            action = { e: StateMachine.Event, r: TestResource ->
                                r.copy(number = (e as TestEvents.EventWithNumber).someNumber)
                            },
                        ),
                        StateTransition(
                            event = TestEvents.AnotherBasicEvent::class,
                            from = TestState.State3,
                            to = TestState.State4,
                            // Returning null means "keep the resource the action was given".
                            action = { _, _ -> null },
                        ),
                        StateTransition(
                            event = TestEvents.BasicEvent::class,
                            from = TestState.State4,
                            to = TestState.State1,
                        ),
                    )
                }

                // 1 to 2
                val output0 = stateMachine.trigger(TestEvents.BasicEvent, resource)
                assertEquals(StateMachine.TriggerResult.StateUpdated, output0.result)
                assertEquals(TestState.State2, stateMachine.currentState)
                assertEquals(resource.copy(state = TestState.State2), output0.resource)

                // 2 to 3
                val output1 =
                    stateMachine.trigger(TestEvents.EventWithNumber(someNumber = 4815), resource)
                assertEquals(TestState.State3, stateMachine.currentState)
                assertEquals(
                    "action should run and return modified resource",
                    resource.copy(state = TestState.State3, number = 4815),
                    output1.resource,
                )

                // 3 to 4: the action returns null, so the listener receives the original resource.
                val output2 = stateMachine.trigger(TestEvents.AnotherBasicEvent, resource)
                assertEquals(TestState.State4, stateMachine.currentState)
                assertEquals(
                    resource.copy(state = TestState.State4),
                    output2.resource,
                )

                // 4 to 1
                val output3 = stateMachine.trigger(TestEvents.BasicEvent, resource)
                assertEquals(TestState.State1, stateMachine.currentState)
                assertEquals(
                    resource.copy(state = TestState.State1),
                    output3.resource,
                )
            }
        }

        describe("transition with a guard and action") {
            // Spek evaluates describe bodies once, during discovery. A machine declared directly in
            // this body would therefore be shared, and mutated, by every test below, so each test
            // builds a fresh one. [onAction] lets a test observe whether the guarded action ran.
            fun newStateMachine(onAction: () -> Unit = {}) =
                StateMachine<TestState, TestResource>(TestState.INITIAL) {
                    onTransition { transition, res -> res.copy(state = transition.to) }
                    addTransitions(
                        StateTransition(
                            event = TestEvents.BasicEvent::class,
                            from = TestState.State1,
                            to = TestState.State2,
                        ),
                        StateTransition(
                            event = TestEvents.EventWithNumber::class,
                            from = TestState.State2,
                            to = TestState.State3,
                            action = { e: StateMachine.Event, r: TestResource ->
                                onAction()
                                r.copy(number = (e as TestEvents.EventWithNumber).someNumber)
                            },
                            guard = { e: StateMachine.Event, _: TestResource ->
                                (e as TestEvents.EventWithNumber).someNumber > 1
                            },
                        ),
                        StateTransition(
                            event = TestEvents.AnotherBasicEvent::class,
                            from = TestState.State3,
                            to = TestState.State4,
                            action = { _, _ -> null },
                        ),
                        StateTransition(
                            event = TestEvents.BasicEvent::class,
                            from = TestState.State4,
                            to = TestState.State1,
                        ),
                    )
                }

            describe("guard returns false") {
                it("stops transition and action is not executed") {
                    var actionExecuted = false
                    val stateMachine = newStateMachine(onAction = { actionExecuted = true })
                    val resource = TestResource(state = TestState.INITIAL)
                    stateMachine.trigger(TestEvents.BasicEvent, resource)
                    assertEquals(TestState.State2, stateMachine.currentState)

                    val output1 =
                        stateMachine.trigger(TestEvents.EventWithNumber(someNumber = -1), resource)

                    assertEquals(StateMachine.TriggerResult.StoppedByGuard, output1.result)
                    assertEquals(TestState.State2, stateMachine.currentState)
                    assertNull(output1.resource)
                    assertEquals("action must not run when the guard rejects the event", false, actionExecuted)
                }
            }

            describe("guard returns true") {
                it("changes state and executes action") {
                    var actionExecuted = false
                    val stateMachine = newStateMachine(onAction = { actionExecuted = true })
                    val resource = TestResource(state = TestState.INITIAL)
                    stateMachine.trigger(TestEvents.BasicEvent, resource)
                    assertEquals(TestState.State2, stateMachine.currentState)

                    val output1 =
                        stateMachine.trigger(TestEvents.EventWithNumber(someNumber = 5), resource)

                    assertEquals(StateMachine.TriggerResult.StateUpdated, output1.result)
                    assertEquals(TestState.State3, stateMachine.currentState)
                    assertEquals("action should run when the guard accepts the event", true, actionExecuted)
                    assertEquals(
                        "action should run and return modified resource",
                        resource.copy(state = TestState.State3, number = 5),
                        output1.resource,
                    )
                }
            }
        }
    }
})
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment