Skip to content

Instantly share code, notes, and snippets.

@rca
Last active September 29, 2023 16:41
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save rca/7421319 to your computer and use it in GitHub Desktop.
Save rca/7421319 to your computer and use it in GitHub Desktop.
Python Finite State Machine implementation; logic mostly extracted from https://github.com/kmmbvnr/django-fsm
"""
Finite State Machine
This FSM implementation is extracted from the django-fsm package and licensed
under the same BSD-like license at:
https://github.com/kmmbvnr/django-fsm/blob/master/LICENSE
Basic usage:
```
from fsm import State, exception_transition, transition


class Container(object):
    # some states
    OFFLINE = 'offline'
    ONLINE = 'online'

    # the state machine
    state_machine = State()

    def __init__(self, container_id):
        self.state = self.OFFLINE
        self.container_id = container_id
        self.log_proc = None

    @transition(state_machine, source=OFFLINE, target=ONLINE)
    def offline(self):
        self.log_proc = attach_to_container(self.container_id)

    @exception_transition((Disconnected,), target=OFFLINE)
    @transition(state_machine, source=ONLINE, target=ONLINE)
    def online(self):
        for log in self.log_proc.get_logs():
            self.do_something(log)

    def loop(self):
        getattr(self, self.state)()
```
"""
from collections import defaultdict
from functools import wraps
class Signal(object):
    """Tiny publish/subscribe hub.

    Receivers are kept together with an optional sender filter and are
    called in the order in which they were connected.
    """

    def __init__(self):
        # (receiver, sender) pairs, in connection order
        self.connections = []

    def connect(self, receiver, sender=None):
        """Subscribe ``receiver``; a ``sender`` of None matches every sender."""
        subscription = (receiver, sender)
        self.connections.append(subscription)

    def send(self, sender, instance, name, source, target):
        """Invoke each receiver whose filter is None or equal to ``sender``.

        The receiver is called with ``sender`` positionally and with
        ``instance``, ``name``, ``source`` and ``target`` as keywords.
        """
        payload = dict(instance=instance, name=name, source=source, target=target)
        for callback, wanted in self.connections:
            if wanted is not None and not (wanted == sender):
                continue
            callback(sender, **payload)
class TransitionNotAllowed(Exception):
    """Error raised when a requested state transition is not permitted."""
class State(object):
    """State field descriptor-like object used by the ``transition`` decorator.

    ``name`` is the attribute name under which the owning instance stores its
    current state; ``transitions`` collects the decorated methods registered
    against this field.

    Comparison, ordering and hashing are all based on ``str(self)``, so a
    ``State`` compares equal to a plain string holding the same state name.
    """

    # attribute name on the owning instance that holds the current state
    name = 'state'

    def __init__(self, state='offline'):
        self.state = state
        self.transitions = []

    def __cmp__(self, other):
        # Python 2 three-way comparison hook, kept for callers using cmp().
        # Computed by hand because the ``cmp`` builtin no longer exists on
        # Python 3.
        mine, theirs = str(self), str(other)
        return (mine > theirs) - (mine < theirs)

    # Rich comparisons: Python 3 ignores __cmp__, so without these equality
    # would silently fall back to identity and ordering would raise.
    def __eq__(self, other):
        return str(self) == str(other)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __lt__(self, other):
        return str(self) < str(other)

    def __le__(self, other):
        return str(self) <= str(other)

    def __gt__(self, other):
        return str(self) > str(other)

    def __ge__(self, other):
        return str(self) >= str(other)

    def __hash__(self):
        # keep hashing consistent with the string-based equality above
        return hash(str(self))

    def __repr__(self):
        return '<State: {}>'.format(self.state)

    def __str__(self):
        # str() guards against a non-string state value, for which __str__
        # would otherwise raise TypeError
        return str(self.state)
class StateMachine(object):
    """Base class that runs an object by calling the method named after its state.

    ``loop`` reads ``self.TERMINATED`` and looks up one method per state
    name; neither is defined here, so they have to come from a subclass.
    """

    def __init__(self, state=None, target_state=None):
        self.target_state = target_state
        self.state = state

    def loop(self, target_state=None):
        """Run state methods until the target state or TERMINATED is reached.

        A truthy ``target_state`` replaces the stored one; otherwise the
        stored target is kept.
        """
        self.target_state = target_state if target_state else self.target_state

        # Starts as None so that a cycle of states which ends back on the
        # starting state (e.g. ONLINE -> ONLINE_TASK -> ONLINE_TASK2 ->
        # ONLINE) still gets executed at least once.
        reached = None
        while reached != self.target_state:
            running = self.state != self.TERMINATED
            if not running:
                break
            handler = getattr(self, self.state)
            handler()
            reached = self.state
# Module-level signals: one is sent before the decorated method runs and the
# other after the state has been switched.
pre_transition, post_transition = Signal(), Signal()
class FSMMeta(object):
    """Per-method registry of allowed transitions and their guard conditions.

    ``transitions`` maps a source state name (or the wildcard ``'*'``) to a
    target state. ``conditions`` maps the same keys to a list of callables
    that receive the instance and must all return a truthy value before the
    transition is allowed.
    """

    def __init__(self, field):
        self.field = field
        # Plain dicts: the previous ``defaultdict()`` had no default factory
        # and therefore already behaved exactly like a dict.
        self.transitions = {}
        self.conditions = {}

    def add_transition(self, source, target, conditions=None):
        """Register ``source -> target`` guarded by ``conditions``.

        Raises AssertionError when ``source`` already has a transition.
        """
        if source in self.transitions:
            raise AssertionError('Duplicate transition for %s state' % source)
        self.transitions[source] = target
        # Store a private copy so registrations never share (or mutate) one
        # list, which the old ``conditions=[]`` default allowed.
        self.conditions[source] = [] if conditions is None else list(conditions)

    def _get_state_field(self, instance):
        return self.field

    def _state_key(self, instance):
        """Return the current state as the string used to key both mappings."""
        return str(self.current_state(instance))

    def current_state(self, instance):
        """
        Return the current state stored on the instance
        """
        field_name = self._get_state_field(instance).name
        return getattr(instance, field_name)

    def next_state(self, instance):
        """
        Return the target registered for the current state, falling back to
        the ``'*'`` wildcard; raises KeyError when neither is registered
        """
        try:
            return self.transitions[self._state_key(instance)]
        except KeyError:
            return self.transitions['*']

    def has_transition(self, instance):
        """
        Lookup if any transition exists from the current state
        """
        return self._state_key(instance) in self.transitions or '*' in self.transitions

    def conditions_met(self, instance):
        """
        Check if all conditions have been met
        """
        # Use the same str() key as the transition lookup; looking up the raw
        # state object would miss the entry and silently skip the guards.
        key = self._state_key(instance)
        if key not in self.conditions:
            key = '*'
        return all(check(instance) for check in self.conditions.get(key, []))

    def to_next_state(self, instance):
        """
        Switch to next state
        """
        field_name = self._get_state_field(instance).name
        state = self.next_state(instance)
        # a falsy target (e.g. None) means "validate only, keep the state"
        if state:
            # written straight into the instance dict, bypassing setattr
            instance.__dict__[field_name] = state
def transition(field, source='*', target=None, save=False, conditions=None):
    """
    Method decorator that marks an allowed transition.

    ``source`` is a single state or a list/tuple of states (``'*'`` matches
    any state). Set ``target`` to None if the current state needs to be
    validated but not changed after the function call. When ``save`` is true
    ``instance.save()`` is called after the state switch. ``conditions`` is
    an optional list of callables taking the instance; all must return a
    truthy value for the transition to be allowed.

    The decorated method raises TransitionNotAllowed when no transition is
    registered for the current state or a condition is not met.
    """
    # None instead of a shared mutable ``[]`` default
    guards = [] if conditions is None else conditions

    # pylint: disable=C0111
    def inner_transition(func):
        if not hasattr(func, '_fsm_meta'):
            setattr(func, '_fsm_meta', FSMMeta(field=field))

            @wraps(func)
            def _change_state(instance, *args, **kwargs):
                meta = func._fsm_meta
                if not (meta.has_transition(instance) and meta.conditions_met(instance)):
                    # __name__ rather than func_name: the latter only exists
                    # on Python 2 function objects
                    raise TransitionNotAllowed(
                        "Can't switch from state '%s' using method '%s'"
                        % (meta.current_state(instance), func.__name__))

                source_state = meta.current_state(instance)

                pre_transition.send(
                    sender=instance.__class__,
                    instance=instance,
                    name=func.__name__,
                    source=source_state,
                    target=meta.next_state(instance))

                result = func(instance, *args, **kwargs)

                meta.to_next_state(instance)
                if save:
                    instance.save()

                post_transition.send(
                    sender=instance.__class__,
                    instance=instance,
                    name=func.__name__,
                    source=source_state,
                    target=meta.current_state(instance))
                return result
        else:
            # Already wrapped by an earlier @transition: reuse that wrapper
            # and only register the additional source state(s) below.
            _change_state = func

        if isinstance(source, (list, tuple)):
            for state in source:
                func._fsm_meta.add_transition(state, target, guards)
        else:
            func._fsm_meta.add_transition(source, target, guards)

        # explicit None check so the result does not depend on how the field
        # object defines truthiness
        if field is not None:
            field.transitions.append(_change_state)
        return _change_state

    return inner_transition
def exception_transition(exceptions, target, reraise=True):
    """
    Decorator to set the state to the given target when the given exceptions are raised.

    ``exceptions`` is an exception class or tuple of classes to intercept.
    When one of them escapes the decorated method, ``self.state`` is set to
    ``target``; the exception is then re-raised unless ``reraise`` is false,
    in which case the wrapper returns None.
    """
    def exception_transition_inner(func):
        @wraps(func)
        def exception_transition_wrapper(self, *args, **kwargs):
            try:
                return func(self, *args, **kwargs)
            # no "as" binding: the exception object itself is not used, and
            # the old "except X, name" form is a syntax error on Python 3
            except exceptions:
                self.state = target
                if reraise:
                    raise
        return exception_transition_wrapper
    return exception_transition_inner
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment