Created September 9, 2012 19:00
-
-
Save aodag/3686516 to your computer and use it in GitHub Desktop.
State Machine
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding:utf-8 -*- | |
from zope.interface import Interface, directlyProvides | |
from zope.interface.adapter import AdapterRegistry | |
class IStateChangeSubscriber(Interface):
    """Callable notified when a StateMachine enters or leaves a state."""

    def __call__(event):
        """Handle *event*, a StateEvent carrying ``state`` and ``model``."""
class StateEvent(object):
    """Event object handed to state-change subscribers.

    Holds the ``State`` being entered or left and the model the state
    machine is driving.  Before dispatch, ``StateMachine`` marks each
    instance with the state's enter/leave event interface via
    ``directlyProvides``.
    """

    def __init__(self, state, model):
        self.state, self.model = state, model
class StateMachine(object):
    """Finite state machine that moves a model between named states.

    Build it with ``add_state`` and ``add_transition`` (``add_transtion`` is
    kept as an alias of the original, misspelled name).  Choose the initial
    state with ``start``, then fire transitions by calling the machine:
    ``machine(model, "transition_name")``.

    Enter/leave subscribers are stored in a zope.interface
    ``AdapterRegistry``, keyed on each state's enter/leave event interface.
    """

    def __init__(self):
        self.states = {}            # state name -> State
        self.transitions = {}       # transition name -> Transition
        self.current_state = None   # set by start()
        self.registry = AdapterRegistry()

    def __call__(self, model, action):
        """Fire the transition named *action* on *model*.

        Runs, in order:

        1. the current state's ``leave`` hook;
        2. its leave subscribers;
        3. the switch to the transition's destination;
        4. the destination's enter subscribers;
        5. the destination's ``enter`` hook.

        :raises KeyError: if no transition is registered under *action*.
        :raises RuntimeError: if ``start`` has not been called yet.
        :raises ValueError: if the transition does not start from the
            current state.
        """
        transition = self.transitions[action]
        if self.current_state is None:
            raise RuntimeError(
                "cannot fire %r: state machine has not been started"
                % (action,))
        # Enforce the transition's declared source.  Previously ``src`` was
        # stored but never checked, so any transition fired from any state.
        if transition.src is not self.current_state:
            raise ValueError(
                "cannot fire %r from state %r: it starts from state %r"
                % (action, self.current_state.name, transition.src.name))
        self.current_state.leave(model)
        self.notify_leave_state(model, self.current_state)
        self.current_state = transition.dest
        self.notify_enter_state(model, self.current_state)
        self.current_state.enter(model)

    def notify_leave_state(self, model, state):
        """Notify subscribers registered for leaving *state*."""
        event = StateEvent(state, model)
        self.notify_state_change(event, state.leave_event_iface)

    def notify_enter_state(self, model, state):
        """Notify subscribers registered for entering *state*."""
        event = StateEvent(state, model)
        self.notify_state_change(event, state.enter_event_iface)

    def notify_state_change(self, event, event_iface):
        """Mark *event* as providing *event_iface*, then call its subscribers.

        Every subscriber registered in ``self.registry`` for *event_iface*
        is called with *event*.
        """
        directlyProvides(event, event_iface)
        for subscriber in self.registry.subscriptions([event_iface],
                                                      IStateChangeSubscriber):
            subscriber(event)

    def maybe_state(self, name):
        """Resolve *name* to a state object.

        A ``str`` is looked up in ``self.states`` (``KeyError`` if unknown).
        Anything else is returned unchanged and is assumed to already be a
        State.
        """
        if isinstance(name, str):
            return self.states[name]
        return name

    def subscribe_enter_state(self, state, subscriber):
        """Call ``subscriber(event)`` whenever a transition enters *state*.

        *state* may be a state name or a State object.
        """
        state = self.maybe_state(state)
        self.registry.subscribe([state.enter_event_iface],
                                IStateChangeSubscriber,
                                subscriber)

    def subscribe_leave_state(self, state, subscriber):
        """Call ``subscriber(event)`` whenever a transition leaves *state*.

        *state* may be a state name or a State object.
        """
        state = self.maybe_state(state)
        self.registry.subscribe([state.leave_event_iface],
                                IStateChangeSubscriber,
                                subscriber)

    def add_state(self, name):
        """Create and register a State called *name*, and return it.

        A state already registered under *name* is replaced.
        """
        state = State(name)
        self.states[name] = state
        return state

    def add_transition(self, name, src, dest):
        """Register transition *name* from state *src* to state *dest*.

        *src* and *dest* are names previously passed to ``add_state``.
        Returns the new Transition.

        :raises KeyError: if either state name is unknown.
        """
        src_state = self.states[src]
        dest_state = self.states[dest]
        transition = Transition(name, src_state, dest_state)
        self.transitions[name] = transition
        return transition

    # Backward-compatible alias for the original, misspelled method name.
    add_transtion = add_transition

    def start(self, model, name):
        """Make state *name* current and run its ``enter`` hook on *model*.

        NOTE: unlike ``__call__``, this does not notify enter subscribers.
        """
        self.current_state = self.states[name]
        self.current_state.enter(model)
class State(object):
    """A named state of a ``StateMachine``.

    Each state creates two interfaces of its own, ``<name>_EnterStateEvent``
    and ``<name>_LeaveStateEvent``.  The machine uses them as registry keys
    for enter and leave subscribers.
    """

    def __init__(self, name):
        self.name = name
        # type(Interface) is zope's interface metaclass.  Calling it builds a
        # new interface dynamically, much like type(name, bases, namespace).
        make_iface = type(Interface)
        self.enter_event_iface = make_iface(name + "_EnterStateEvent",
                                            (Interface,), {})
        self.leave_event_iface = make_iface(name + "_LeaveStateEvent",
                                            (Interface,), {})

    def enter(self, model):
        """Hook run when the machine enters this state; does nothing here."""

    def leave(self, model):
        """Hook run when the machine leaves this state; does nothing here."""
class Transition(object):
    """A named, directed edge between two states.

    ``StateMachine.add_transtion`` builds these from registered State
    objects.  ``src`` is the state the transition leaves and ``dest`` is the
    state it leads to.
    """

    def __init__(self, name, src, dest):
        self.name, self.src, self.dest = name, src, dest
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest | |
from unittest import mock | |
class StateTests(unittest.TestCase):
    """Unit tests for ``statemachine.State``."""

    def _getTarget(self):
        from statemachine import State
        return State

    def _makeOne(self, *args, **kwargs):
        state_cls = self._getTarget()
        return state_cls(*args, **kwargs)

    def test_init(self):
        state = self._makeOne("state1")
        self.assertEqual(state.name, "state1")
        # Each state gets its own enter/leave interfaces named after it.
        expected_iface_names = [
            ("enter_event_iface", "state1_EnterStateEvent"),
            ("leave_event_iface", "state1_LeaveStateEvent"),
        ]
        for attr, iface_name in expected_iface_names:
            self.assertEqual(getattr(state, attr).__name__, iface_name)
class StateMachineTests(unittest.TestCase):
    """Integration test: subscribers fire when a transition runs."""

    def _getTarget(self):
        from statemachine import StateMachine
        return StateMachine

    def _makeOne(self, *args, **kwargs):
        machine_cls = self._getTarget()
        return machine_cls(*args, **kwargs)

    def test_it(self):
        DummySubscriber.reset()
        model = object()
        machine = self._makeOne()
        for state_name in ('s1', 's2'):
            machine.add_state(state_name)
        # s2 has two enter subscribers, so both must run for the final check
        # to pass.  s1 has a single leave subscriber.
        for _ in range(2):
            machine.subscribe_enter_state('s2',
                                          DummySubscriber('s2', model))
        machine.subscribe_leave_state('s1',
                                      DummySubscriber('s1', model))
        machine.add_transtion('t1', 's1', 's2')
        machine.start(model, 's1')
        machine(model, 't1')
        DummySubscriber.assertCalledAll()
class DummySubscriber(object):
    """Test double that checks the event it receives and records the call.

    Every instance appends itself to the class-level ``instances`` list, so
    ``assertCalledAll`` can verify that all of them fired.  That list is
    shared class state; call ``reset`` at the start of each test.
    """

    instances = []

    @classmethod
    def assertCalledAll(cls):
        """Fail unless every instance created since ``reset`` was called."""
        # Raise explicitly rather than using ``assert``, so the check still
        # runs under ``python -O``.  The message lists the states whose
        # subscribers never fired.
        not_called = [i.state_name for i in cls.instances if not i.called]
        if not_called:
            raise AssertionError(
                "subscribers never called for states: %r" % (not_called,))

    @classmethod
    def reset(cls):
        """Forget every previously created instance."""
        cls.instances = []

    def __init__(self, state_name, model):
        self.state_name = state_name  # expected ``event.state.name``
        self.model = model            # expected ``event.model``
        self.called = False
        type(self).instances.append(self)

    def __call__(self, event):
        """Check *event* against the expected state and model; mark called."""
        if event.state.name != self.state_name:
            raise AssertionError(
                "expected state %r, got %r"
                % (self.state_name, event.state.name))
        if event.model != self.model:
            raise AssertionError(
                "expected model %r, got %r" % (self.model, event.model))
        self.called = True
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment