Created
July 31, 2014 01:59
-
-
Save harlowja/e87017a585a706166d70 to your computer and use it in GitHub Desktop.
nested_fsm
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
# Copyright (C) 2014 Yahoo! Inc. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); you may | |
# not use this file except in compliance with the License. You may obtain | |
# a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | |
# License for the specific language governing permissions and limitations | |
# under the License. | |
import logging | |
import collections | |
try: | |
from collections import OrderedDict # noqa | |
except ImportError: | |
from ordereddict import OrderedDict # noqa | |
import prettytable | |
import six | |
from taskflow import exceptions as excp | |
# Module-level logger; FSM.run_iter() uses it to debug-log which machine in
# the nested hierarchy is being offered each event.
LOG = logging.getLogger(__name__)
class _Jump(object): | |
"""A FSM transition tracks this data while jumping.""" | |
def __init__(self, name, on_enter, on_exit): | |
self.name = name | |
self.on_enter = on_enter | |
self.on_exit = on_exit | |
# Result of processing an event... | |
_Effect = collections.namedtuple("_Effect", | |
"reaction,terminal,nested_machine") | |
class NotInitialized(excp.TaskFlowException):
    """Raised when a machine is driven before ``initialize()`` has run.

    For example ``FSM.process_event`` raises this while the machine has no
    current state yet.
    """
class FSM(object):
    """A finite state machine.

    This state machine can be used to automatically run a given set of
    transitions and states in response to events (either from callbacks or
    from generator/iterator send() values, see PEP 342). On each triggered
    event, an on_enter and on_exit callback can also be provided which will
    be called to perform some type of action on leaving a prior state and
    before entering a new state.

    A state may also be given a nested state machine (see add_state()). When
    ran via run()/run_iter(), events are offered to the innermost active
    machine first and fall back to its parent machine(s) when it can not
    handle them.

    NOTE(harlowja): reactions will *only* be called when the
    generator/iterator from run_iter() does *not* send back a new event
    (they will always be called if the run() method is used). This allows
    for two unique ways (these ways can also be intermixed) to use this
    state machine when using run_iter(); one where *external* events trigger
    the next state transition and one where *internal* reaction callbacks
    trigger the next state transition. The other way to use this state
    machine is to skip using run() or run_iter() completely and use the
    process_event() method explicitly and trigger the events via some
    *external* functionality.
    """

    def __init__(self, start_state):
        # state -> OrderedDict(event -> _Jump describing the destination).
        self._transitions = {}
        # state -> {'terminal': bool,
        #           'reactions': {event: (callback, args, kwargs)},
        #           'nested_machine': FSM or None}
        # kept in the order the states were added.
        self._states = OrderedDict()
        self._start_state = start_state
        # The _Jump that moved the machine into its current state; None
        # until initialize() has been called.
        self._current = None

    @property
    def start_state(self):
        """Returns the name of the state the machine starts in."""
        return self._start_state

    @property
    def current_state(self):
        """Returns the current state name (None if not yet initialized)."""
        if self._current is not None:
            return self._current.name
        return None

    @property
    def terminated(self):
        """Returns whether the state machine is in a terminal state."""
        if self._current is None:
            return False
        return self._states[self._current.name]['terminal']

    def add_state(self, state, terminal=False, nested_machine=None):
        """Adds a given state to the state machine.

        :param state: name of the new state (must not already exist).
        :param terminal: whether this state is terminal (process_event()
                         refuses to transition out of a terminal state).
        :param nested_machine: optional FSM that run()/run_iter() will start
                               routing events to once this state is entered.
        :raises ValueError: if nested_machine is given but is not a FSM.
        :raises Duplicate: if the state was already added.
        """
        if nested_machine is not None and not isinstance(nested_machine, FSM):
            raise ValueError(
                "Nested state machines must themselves be state machines")
        if state not in self._states:
            self._states[state] = {
                'terminal': bool(terminal),
                'reactions': {},
                'nested_machine': nested_machine,
            }
            self._transitions[state] = OrderedDict()
        else:
            raise excp.Duplicate("State '%s' already defined" % state)

    def add_reaction(self, state, event, reaction, *args, **kwargs):
        """Adds a reaction that may get triggered by the given event & state.

        Reaction callbacks may (depending on how the state machine is ran) be
        used after an event is processed (and a transition occurs) to cause
        the machine to react to the newly arrived at stable state.

        These callbacks are expected to accept three default positional
        parameters (although more can be passed in via *args and **kwargs,
        these will automatically get provided to the callback when it is
        activated *on top* of the three default). The three default
        parameters are the last stable state, the new stable state and the
        event that caused the transition to this new stable state to be
        arrived at.

        The expected result of a callback is expected to be a new event that
        the callback wants the state machine to react to. This new event
        may (depending on how the state machine is ran) get processed (and
        this process typically repeats) until the state machine reaches a
        terminal state.

        :raises NotFound: if the state has not been added.
        :raises ValueError: if the reaction is not callable.
        :raises Duplicate: if a reaction for this state & event exists.
        """
        if state not in self._states:
            raise excp.NotFound("Can not add a reaction to event '%s' for an"
                                " undefined state '%s'" % (event, state))
        # Explicit check (not an assert) so validation still happens when
        # python runs with -O; also matches the ValueError add_state raises.
        if not six.callable(reaction):
            raise ValueError("Reaction callback must be callable")
        if event not in self._states[state]['reactions']:
            self._states[state]['reactions'][event] = (reaction, args, kwargs)
        else:
            raise excp.Duplicate("State '%s' reaction to event '%s'"
                                 " already defined" % (state, event))

    def add_transition(self, start, end, event, on_enter=None, on_exit=None):
        """Adds an allowed transition from start -> end for the given event.

        The on_enter and on_exit callbacks, if provided will be expected to
        take two positional parameters, these being the state being exited
        (for on_exit) or the state being entered (for on_enter) and a second
        parameter which is the event that is being processed to cause this
        transition.

        NOTE: adding a transition for an already defined (start, event) pair
        replaces the previous transition.

        :raises NotFound: if start or end has not been added as a state.
        :raises ValueError: if on_enter/on_exit is given but not callable.
        """
        if start not in self._states:
            raise excp.NotFound("Can not add a transition on event '%s' that"
                                " starts in a undefined state '%s'" % (event,
                                                                     start))
        if end not in self._states:
            raise excp.NotFound("Can not add a transition on event '%s' that"
                                " ends in a undefined state '%s'" % (event,
                                                                   end))
        # Explicit checks (not asserts) so validation survives python -O.
        if on_enter is not None and not six.callable(on_enter):
            raise ValueError("On enter callback must be callable")
        if on_exit is not None and not six.callable(on_exit):
            raise ValueError("On exit callback must be callable")
        self._transitions[start][event] = _Jump(end, on_enter, on_exit)

    def process_event(self, event):
        """Trigger a state change in response to the provided event.

        Calls the on_exit callback attached to the transition that led into
        the current state, then the on_enter callback of the new transition,
        and only then moves the machine into the new state.

        :returns: an _Effect with the reaction registered for (new state,
                  event) (or None), whether the new state is terminal and
                  the new state's nested machine (or None).
        :raises NotInitialized: if initialize() has not been called.
        :raises InvalidState: if the current state is terminal.
        :raises NotFound: if no transition exists for this event.
        """
        current = self._current
        if current is None:
            raise NotInitialized("Can only process events after"
                                 " being initialized (not before)")
        if self._states[current.name]['terminal']:
            raise excp.InvalidState("Can not transition from terminal"
                                    " state '%s' on event '%s'"
                                    % (current.name, event))
        if event not in self._transitions[current.name]:
            raise excp.NotFound("Can not transition from state '%s' on"
                                " event '%s' (no defined transition)"
                                % (current.name, event))
        replacement = self._transitions[current.name][event]
        if current.on_exit is not None:
            current.on_exit(current.name, event)
        if replacement.on_enter is not None:
            replacement.on_enter(replacement.name, event)
        self._current = replacement
        return _Effect(self._states[replacement.name]['reactions'].get(event),
                       self._states[replacement.name]['terminal'],
                       self._states[replacement.name]['nested_machine'])

    def _initialize(self, seen):
        # Resets this machine (and, recursively, every nested machine) to
        # its start state; ``seen`` guards against initializing a machine
        # more than once when it is reachable through several states.
        if self in seen:
            return
        if self._start_state not in self._states:
            raise excp.NotFound("Can not start from a undefined"
                                " state '%s'" % (self._start_state))
        if self._states[self._start_state]['terminal']:
            raise excp.InvalidState("Can not start from a terminal"
                                    " state '%s'" % (self._start_state))
        self._current = _Jump(self._start_state, None, None)
        seen.add(self)
        for data in six.itervalues(self._states):
            nested_machine = data['nested_machine']
            if nested_machine is not None:
                nested_machine._initialize(seen)

    def initialize(self):
        """Sets up the state machine (sets current state to start state...).

        NOTE(harlowja): also initializes any nested state machines...

        :raises NotFound: if the start state (of any machine) is undefined.
        :raises InvalidState: if the start state (of any machine) is
                              terminal.
        """
        self._initialize(set())

    def run(self, event, initialize=True):
        """Runs the state machine, using reactions only.

        Drives run_iter() to completion without sending any events in, so
        every step after the first must be produced by a reaction callback.
        """
        for _transition in self.run_iter(event, initialize=initialize):
            pass

    def run_iter(self, event, initialize=True):
        """Runs the state machine, applying reactions or sent in events.

        Yields ``(old_state, new_state)`` tuples after each transition; a
        caller may send() back a new event to process next, otherwise the
        reaction registered for the new state (if any) supplies it.

        When the machine contains nested state machines, this is currently
        the only way to run the root machine and the nested machine in a
        manner that makes sense (since the active hierarchy of machines is
        currently kept on the runtime stack). This means those using FSM
        running via process_event would need to maintain a similar stack, we
        should probably look at refactoring this if others want to avoid
        duplicating that same logic...
        """
        if initialize:
            self.initialize()

        def process_event(event, machines):
            """Matches an event to a machine hierarchy.

            If the lowest level machine does not handle the event, then the
            parent machine is referred to and so on, until there is only one
            machine left which *must* handle the event.

            The machine whose process_event does not throw invalid state or
            not found exceptions is expected to be the machine (or its
            returned nested machine) that should continue handling events.

            Returns ``(old_state, effect)`` where ``old_state`` is the state
            the machine that actually handled the event was in beforehand
            (not the state of a nested machine that was popped off).
            """
            while True:
                machine = machines[-1]
                # TODO(harlowja): remove this, it is very very verbose...
                LOG.debug("Trying machine '%s' to see if it can process"
                          " event '%s'", machine, event)
                old_state = machine.current_state
                try:
                    effect = machine.process_event(event)
                except (excp.InvalidState, excp.NotFound):
                    if len(machines) == 1:
                        raise
                    machines.pop()
                else:
                    return old_state, effect

        # NOTE(review): a nested machine is only reset by initialize(); if
        # its parent state is re-entered later, the nested machine resumes
        # from whatever state it was last left in -- confirm this is
        # intended.
        stack = [self]
        while True:
            old_state, effect = process_event(event, stack)
            new_state = stack[-1].current_state
            if effect.nested_machine is not None:
                stack.append(effect.nested_machine)
            try:
                sent_event = yield (old_state, new_state)
            except GeneratorExit:
                break
            if len(stack) == 1 and effect.terminal:
                # Only allow the top level machine to actually terminate the
                # execution, the rest of the nested machines must not handle
                # events if they wish to have the root machine terminate...
                break
            if effect.reaction is None and sent_event is None:
                raise excp.NotFound("Unable to progress since no reaction (or"
                                    " sent event) has been made available in"
                                    " new state '%s' (moved to from state '%s'"
                                    " in response to event '%s')"
                                    % (new_state, old_state, event))
            elif sent_event is not None:
                event = sent_event
            else:
                cb, args, kwargs = effect.reaction
                event = cb(old_state, new_state, event, *args, **kwargs)

    def __contains__(self, state):
        """Returns whether the given state has been added."""
        return state in self._states

    @property
    def states(self):
        """Returns the state names (in the order they were added)."""
        return list(six.iterkeys(self._states))

    @property
    def events(self):
        """Returns how many (state, event) transitions have been defined."""
        return sum(len(self._transitions[state])
                   for state in six.iterkeys(self._states))

    def __iter__(self):
        """Iterates over (start, event, end) transition tuples."""
        for state in six.iterkeys(self._states):
            for event, target in six.iteritems(self._transitions[state]):
                yield (state, event, target.name)

    def pformat(self, sort=True):
        """Pretty formats the state + transition table into a string.

        NOTE(harlowja): the sort parameter can be provided to sort the states
        and transitions by sort order; with it being provided as false the
        rows will be iterated in addition order instead.

        State names are decorated with markers: a ``@`` prefix marks the
        current state, and a bracketed suffix holds ``^`` for the start
        state, ``$`` for terminal states and ``+`` for states that have a
        nested state machine.
        """
        def orderedkeys(data):
            if sort:
                return sorted(six.iterkeys(data))
            return list(six.iterkeys(data))

        def callback_name(cb):
            # Prefer the callable's __name__, falling back to the object
            # itself for callables that do not have one.
            if cb is None:
                return ''
            try:
                return cb.__name__
            except AttributeError:
                return cb

        tbl = prettytable.PrettyTable(
            ["Start", "Event", "End", "On Enter", "On Exit"])
        for state in orderedkeys(self._states):
            prefix_markings = []
            if self.current_state == state:
                prefix_markings.append("@")
            postfix_markings = []
            if self.start_state == state:
                postfix_markings.append("^")
            if self._states[state]['terminal']:
                postfix_markings.append("$")
            if self._states[state]['nested_machine'] is not None:
                postfix_markings.append("+")
            pretty_state = "%s%s" % ("".join(prefix_markings), state)
            if postfix_markings:
                pretty_state += "[%s]" % "".join(postfix_markings)
            if self._transitions[state]:
                for event in orderedkeys(self._transitions[state]):
                    target = self._transitions[state][event]
                    tbl.add_row([pretty_state, event, target.name,
                                 callback_name(target.on_enter),
                                 callback_name(target.on_exit)])
            else:
                tbl.add_row([pretty_state, "", "", "", ""])
        return tbl.get_string(print_empty=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment