Created
May 19, 2021 22:29
-
-
Save cmorss/c591a805fa0cfdcc2e8cd1f6b12e48fc to your computer and use it in GitHub Desktop.
Statemachine
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Mark an ActiveRecord model as having a state machine.
#
# The state is read from and written to a single attribute of the model.
# `stateful` (the usual DSL entry point) uses the attribute named by its
# `column:` argument, which defaults to `:status`. `machine` and
# `register_listener` default to `:state` instead. The machine is memoized per
# class, so whichever of these is called first decides the column.
# NOTE(review): the `:status` vs `:state` defaults are inconsistent; they are
# kept unchanged for backward compatibility.
#
# To set the default state for a class, set a default column value in the
# database, or use `initial: true` when defining the initial state.
#
# Use Symbols for all keys and values in state definitions.
#
# Example:
#
#   class Car < ActiveRecord::Base
#     include Stateful
#
#     stateful column: :state do
#       state :parked, initial: true
#       state :driving
#       on(:start) { move parked: :driving }
#       entered(:driving) { |ctx| ... }
#     end
#   end
#
#   car.can_start? # => true
#   car.start!     # assigns "driving" and saves the record
module Stateful
  extend ActiveSupport::Concern

  # The three points in a transition at which hooks and listeners run.
  module Transitions
    ENTERING = :entering
    ENTERED = :entered
    PERSISTED = :persisted
    ALL = [ENTERING, ENTERED, PERSISTED].freeze
  end

  # Base error for the state machine. Carries the model instance, the event
  # name and the source state when they are known.
  class StatefulError < RuntimeError
    attr_reader :instance, :event, :state

    def initialize(msg = nil, instance = nil, event = nil, state = nil)
      super(msg)
      @instance = instance
      @event = event
      @state = state
    end
  end

  # Raised when an event cannot fire: there is no move from the current state,
  # or the event's `allows` criteria are not met. `Machine#transition_valid?`
  # also raises it for an undeclared event.
  class TransitionError < StatefulError; end

  # Raised when a hook references a state that has not been declared.
  class StateNotFoundError < StatefulError; end

  # Raised by `Machine#fire` for an event that was never declared.
  class EventNotFoundError < StatefulError; end

  class_methods do
    # Returns this model's state machine. It is memoized, so `column` only
    # takes effect on the first call.
    def machine(column: :state)
      @machine ||= Machine.new self, state_column: (column || :state)
    end

    # DSL entry point: evaluates `block` in the context of the machine
    # (`state`, `on`, `entered`, `allows`, ...).
    def stateful(column: :status, &block)
      machine(column: column).instance_eval(&block)
    end

    # @return [Array<String>] names of all declared states
    def states
      machine.states
    end

    # Registers a listener on this model's machine (see Machine#listener).
    def register_listener(listener, opts = {}, &block)
      machine(column: (opts[:column] || :state)).listener(listener, opts, &block)
    end
  end

  # Wildcard sentinel. Hooks registered for ANY run on every transition, and
  # moves declared from ANY apply from every source state.
  ANY = Object.new

  # Context passed to the `entering`, `entered` and `persisted` hooks.
  Context = Struct.new(:instance, :event, :src, :dest, :args) do
    # Runs the hooks registered for `dest` and for ANY. Each hook is
    # instance_exec'd on the model instance and receives this context. Then
    # notifies every listener that applies to this context and transition.
    def trigger(hooks, listeners, transition:)
      hooks.values_at(dest, ANY).compact.flatten.each do |hook|
        instance.instance_exec self, &hook
      end
      listeners.each do |listener|
        listener.trigger(self, transition: transition) if listener.applies?(self, transition: transition)
      end
    end
  end

  # Per-state options. `log_transition_time: :first_only` makes the
  # `<state>_at` timestamp record only the first entry into the state.
  StateProperties = Struct.new(:name, :log_first_transition_only) do
    def initialize(name, opts = {})
      self.name = name.to_s
      self.log_first_transition_only = opts[:log_transition_time] == :first_only
    end
  end

  # One allowed move for an event. `condition` is an optional callable that
  # receives the instance. A truthy `no_op` makes the event do nothing from
  # `src`; if `no_op` responds to `call`, it is called with the instance.
  Move = Struct.new(:src, :dest, :condition, :no_op)

  # Events are things like `start` or `unload`. They cause the state machine
  # to transition from one state to another. Creating an Event defines two
  # methods on the model:
  # - `<event>!` (e.g. `car.start!`) fires the event.
  # - `can_<event>?` (e.g. `car.can_start?`) reports whether the event can
  #   currently fire.
  class Event
    def initialize(model, name)
      @moves = Hash.new { |h, k| h[k] = [] }
      # Define transition methods
      model.public_send :define_method, "#{name}!" do |*args|
        model.machine.fire self, name, *args
      end
      # Define transition predicates
      model.public_send :define_method, "can_#{name}?" do |*args|
        model.machine.transition_valid? self, name, *args
      end
    end

    # Wildcard source for use in `move any => :dest`.
    def any
      ANY
    end

    # Finds the move to take from the `current` state for `instance`.
    #
    # Moves declared for `current` are tried first, in declaration order,
    # followed by every move declared from `any`. A conditional move is
    # eligible only when its condition, called with `instance`, is truthy.
    #
    # @param current [Symbol, String, nil] current state; nil matches only
    #   `any` moves
    # @param instance the stateful model instance that is transitioning
    # @return [Move, nil] the first eligible move, or nil if none applies
    def dest(current, instance)
      # `fetch` avoids inserting empty entries through the default proc.
      candidates = @moves.fetch(current&.to_sym, []) + @moves.fetch(ANY, [])
      candidates.find { |move| move.condition.nil? || move.condition.call(instance) }
    end

    # Declares a move, e.g. `move parked: :driving` or
    # `move [:a, :b] => :c, if: ->(instance) { ... }`.
    #
    # @param pair [Hash] one `source(s) => destination` entry, plus the
    #   optional `:if` (condition) and `:no_op` keys. The caller's hash is not
    #   modified.
    def move(pair)
      pair = pair.dup
      condition = pair.delete(:if)
      no_op = pair.delete(:no_op)
      sources, destination = pair.first
      Array(sources).each do |source|
        @moves[source] << Move.new(source, destination, condition, no_op)
      end
    end
  end

  # A subscriber notified of transitions. The listener object is obtained by
  # calling `listener_name` on the transitioning instance. The listener is
  # then either sent `options[:method]`, or has the block instance_exec'd on
  # it. In both cases it receives a Hash describing the transition.
  class Listener
    # @param listener_name [Symbol] method on the transitioning instance that
    #   returns the listener object
    # @param options [Hash]
    #   :on - event name(s) to react to (every event if omitted)
    #   :entering, :entered, :persisted - destination state(s) to react to
    #     in that phase. If none of the three is given, every phase applies.
    #   :listening_class, :observed_class - when both are set, the listener
    #     reacts only when the instance and the listener object are of these
    #     classes
    #   :method - method called on the listener object when no block is given
    #   :observed_instance - stored but currently unused
    def initialize(listener_name, options, &block)
      @listener_name = listener_name
      @block = block
      @events = Array(options[:on]).compact
      @transitions = options.slice(*Stateful::Transitions::ALL)
      @listening_class = options[:listening_class]
      @observed_class = options[:observed_class]
      @observed_instance = options[:observed_instance]
      @method = options[:method]
    end

    # Whether this listener should be notified for `context` at `transition`.
    def applies?(context, transition:)
      event_applies?(context) &&
        transition_applies?(transition, context) &&
        class_types_match?(context)
    end

    # Notifies the listener object with a Hash with the keys :originator,
    # :transition, :event, :source and :destination.
    # NOTE(review): raises if `listener_name` returns nil, or if neither a
    # block nor :method was given.
    def trigger(context, transition:)
      args = {
        originator: context.instance,
        transition: transition,
        event: context.event,
        source: context.src,
        destination: context.dest
      }
      target = context.instance.public_send(@listener_name)
      @block ? target.instance_exec(args, &@block) : target.public_send(@method, args)
    end

    private

    def transition_applies?(transition, context)
      @transitions.empty? || Array(@transitions[transition]).include?(context.dest)
    end

    def event_applies?(context)
      @events.empty? || @events.include?(context.event)
    end

    def class_types_match?(context)
      return true unless @observed_class && @listening_class

      context.instance.is_a?(@observed_class) &&
        context.instance.public_send(@listener_name).is_a?(@listening_class)
    end
  end

  # Holds the states, events, hooks and listeners for one model class.
  class Machine
    def initialize(model, opts = {})
      @state_column = opts[:state_column] || :state
      @model = model
      @state_properties = []
      @initial_state = nil
      @entered = Hash.new { |h, k| h[k] = [] }
      @entering = Hash.new { |h, k| h[k] = [] }
      @persisted = Hash.new { |h, k| h[k] = [] }
      # Referencing an unknown event creates it, which defines its model methods.
      @events = Hash.new { |h, k| h[k] = Event.new(@model, k) }
      @allows = {}
      @listeners = []
    end

    # Registers a hook run after the state attribute is assigned, but before
    # saving. With no names, the hook runs for every destination state.
    def entered(*names, &block)
      add_hook(@entered, names, block)
    end

    # Registers a hook run before the state attribute is assigned. With no
    # names, the hook runs for every destination state.
    def entering(*names, &block)
      add_hook(@entering, names, block)
    end

    # Registers a hook run after the record is saved. With no names, the hook
    # runs for every destination state.
    def persisted(*names, &block)
      add_hook(@persisted, names, block)
    end

    # Fires event `name` on `instance`. Inside a transaction on the model's
    # connection, it:
    # 1. runs the `entering` hooks;
    # 2. assigns the destination state;
    # 3. runs the `entered` hooks;
    # 4. records the `<state>_at` and `<state>_count` attributes;
    # 5. calls `save!`;
    # 6. runs the `persisted` hooks.
    #
    # @return the instance, or nil when the matched move is a no-op
    # @raise [EventNotFoundError] if the event was never declared
    # @raise [TransitionError] if no move applies from the current state, or
    #   the event's `allows` criteria are not met
    def fire(instance, name, *args)
      unless @events.include?(name)
        raise EventNotFoundError.new("No [#{name}] event.", instance, name)
      end

      src = current_state(instance)
      move = @events[name].dest(src, instance)
      unless move
        raise TransitionError.new("Cannot [#{name}] while in [#{src}]: #{instance}", instance, name, src)
      end

      # Make sure the additional transition criteria were met
      unless criteria_met?(instance, name)
        message = "Cannot [#{name}] while transition criteria not met: #{instance}"
        raise TransitionError.new(message, instance, name, src)
      end

      if move.no_op
        # A no-op move assigns nothing, saves nothing and runs no hooks.
        move.no_op.call(instance) if move.no_op.respond_to?(:call)
        return
      end

      ctx = Context.new(instance, name, src&.to_sym, move.dest, args)
      # Use the model's own connection, so `save!` is covered by the
      # transaction even when the model is not on the primary database.
      instance.class.transaction do
        ctx.trigger @entering, @listeners, transition: Stateful::Transitions::ENTERING
        instance.public_send("#{@state_column}=", ctx.dest.to_s)
        ctx.trigger @entered, @listeners, transition: Stateful::Transitions::ENTERED
        log_transition(ctx)
        instance.save!
        ctx.trigger @persisted, @listeners, transition: Stateful::Transitions::PERSISTED
      end
      instance
    end

    # Whether event `name` could fire on `instance` right now. The current
    # state is read the same way `fire` reads it.
    #
    # @return [Boolean]
    # @raise [TransitionError] if the event was never declared
    def transition_valid?(instance, name, *_args)
      raise TransitionError.new("No [#{name}] event.", instance, name) unless @events.include?(name)

      move = @events[name].dest(current_state(instance), instance)
      return false if move.nil?

      criteria_met?(instance, name)
    end

    # @return [StateProperties, nil] the declaration for state `name`
    def state_property(name)
      @state_properties.detect { |prop| prop.name == name.to_s }
    end

    # Declares moves for one or more events. The block is instance_eval'd on
    # each Event, so it can call `move` and `any`.
    def on(*names, &block)
      names.flatten.each { |n| @events[n].instance_eval(&block) }
    end

    # Adds criteria for the given events. The block is instance_exec'd on the
    # model instance and must return a truthy value, or the event fails. Only
    # the last block registered for an event is kept.
    def allows(*names, &block)
      names.flatten.each { |n| @allows[n] = block }
    end

    # Spies on state transitions of association `name`. This requires
    # `:inverse_of` to be specified in both directions.
    #
    # @raise [ArgumentError] if the association or its inverse is missing
    def spy(name, opts = {}, &block)
      reflection = @model.reflect_on_all_associations.detect { |a| a.name == name }
      raise ArgumentError, "Could not find association #{name} on #{@model.name}" unless reflection

      inverse = reflection.inverse_of
      listener_name = inverse&.name
      if listener_name.nil?
        raise ArgumentError, 'Expected `:inverse_of` to be defined on both parent and child '\
                             "associations when spying on #{name} with options [#{opts.inspect}]"
      end

      opts = opts.merge(listening_class: @model, observed_class: inverse.active_record)
      inverse.active_record.register_listener(listener_name, opts, &block)
    end

    # Declares a state.
    #
    # @param name [Symbol, String] state name (stored as a String)
    # @param opts [Hash] `initial: true` marks the initial state.
    #   `log_transition_time: :first_only` makes `<state>_at` record only the
    #   first entry into the state.
    # @raise [ArgumentError] if a second initial state is declared
    #
    # This also defines the scopes `<state>` and `not_<state>` and the
    # predicate `<state>?`. Each is skipped if the model already has it.
    def state(name, opts = {})
      name = name.to_s
      # Redeclaring a state keeps the first declaration's properties.
      @state_properties << StateProperties.new(name, opts) unless state_property(name)

      # Pull this into a local so the blocks below can close over it.
      column = @state_column
      if opts[:initial]
        raise ArgumentError, "Initial state of [#{@initial_state}] already specified" if @initial_state

        @initial_state = name
        # If the state wasn't set when creating, set it to the default.
        # NOTE(review): records loaded with a `select` that omits this column
        # will raise when the reader is called here. Confirm callers avoid that.
        if @model.respond_to?(:after_initialize)
          @model.after_initialize do |instance|
            instance.public_send(:"#{column}=", name) unless instance.public_send(column)
          end
        end
      end

      # Don't clobber an existing class method such as `new` with a scope.
      name_taken = @model.respond_to?(name)
      @model.scope name, -> { where(column => name) } unless name_taken
      not_name = "not_#{name}"
      @model.scope not_name, -> { where.not(column => name) } unless @model.respond_to?(not_name)

      # Keep an existing predicate only when the name was already taken
      # (e.g. an enum, or a state declared twice).
      return if name_taken && @model.method_defined?("#{name}?")

      initial = @initial_state
      @model.public_send(:define_method, "#{name}?") do
        name == (send(column)&.to_s || initial)
      end
    end

    # @return [Array<String>] names of all declared states
    def states
      @state_properties.map(&:name)
    end

    # Adds a listener to state transitions:
    #
    # listener_name        Required method on the instance that returns the
    #                      listener object.
    # method:              Method to call on the listener object when a
    #                      transition occurs. Used when no block is given.
    # on: [events]         Optional event filter: react only if the fired event
    #                      is listed. Either an array or a single event.
    # entering: [states]   Optional per-phase filters on the destination state.
    # entered: [states]    If any of the three is given, the listener only
    # persisted: [states]  reacts in those phases, for those states.
    # block                Optional block, instance_exec'd on the listener
    #                      object instead of calling `method`.
    def listener(listener_name, options = {}, &block)
      @listeners << Listener.new(listener_name, options, &block)
    end

    private

    # Current state of `instance`, falling back to the declared initial state.
    def current_state(instance)
      instance.public_send(@state_column) || @initial_state
    end

    # Whether the `allows` criteria for event `name` hold for `instance`.
    def criteria_met?(instance, name)
      criteria = @allows[name]
      return true if criteria.nil?

      instance.instance_exec(&criteria) ? true : false
    end

    # Validates and stores `block` under each of `names`, or under ANY when
    # no names are given. Nested arrays of names are accepted.
    def add_hook(store, names, block)
      names = names.flatten
      validate_states(names)
      munge(names).each { |n| store[n] << block }
    end

    # Sets `<dest>_at` and increments `<dest>_count`, when the model has
    # those attributes.
    def log_transition(context)
      instance = context.instance

      timestamp_attr = "#{context.dest}_at"
      if instance.has_attribute?(timestamp_attr)
        # The destination may not have been declared with `state`.
        first_only = state_property(context.dest)&.log_first_transition_only
        if instance.public_send(timestamp_attr).nil? || !first_only
          instance.public_send("#{timestamp_attr}=", Time.current)
        end
      end

      count_attr = "#{context.dest}_count"
      return unless instance.has_attribute?(count_attr)

      instance.public_send("#{count_attr}=", (instance.public_send(count_attr) || 0) + 1)
    end

    # Flattens the state names and converts an empty list to [ANY]. This lets
    # hooks be registered without arguments, meaning "every transition".
    # Does not modify `names`.
    def munge(names)
      flat = names.flatten
      flat.empty? ? [ANY] : flat
    end

    # Raises StateNotFoundError unless every entry of `states` is the Symbol of
    # a declared state.
    def validate_states(states)
      known = @state_properties.map { |prop| prop.name.to_sym }
      states.each do |state|
        raise StateNotFoundError, "Unknown state: '#{state}'" unless known.include?(state)
      end
    end
  end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment