# Statemachine — GitHub Gist cmorss/c591a805fa0cfdcc2e8cd1f6b12e48fc
# by @cmorss, created May 19, 2021 22:29.
# Mark an ActiveRecord model as having a state machine.
# The state is stored on the attribute named by the `column:` argument
# to `stateful`. Note that `stateful` defaults `column:` to `:status`,
# while the underlying machine (and `register_listener`) default to
# `:state`. To set the default state for a class, either give the state
# column a default value in the database or pass `initial: true` when
# declaring the initial state.
#
# Use Symbols for all keys and values in state definitions.
module Stateful
extend ActiveSupport::Concern
# Names of the phases of a transition that hooks and listeners can attach
# to, listed in the order Machine#fire triggers them.
module Transitions
  ENTERING  = :entering  # before the state attribute is assigned
  ENTERED   = :entered   # after assignment, before the record is saved
  PERSISTED = :persisted # after save!, still inside the transaction

  ALL = [ENTERING, ENTERED, PERSISTED].freeze
end
# Base class for every error raised by the state machine. Besides the
# message it carries the record, the event name and the source state
# involved; any of them may be nil.
class StatefulError < RuntimeError
  attr_reader :instance, :event, :state

  def initialize(msg = nil, instance = nil, event = nil, state = nil)
    super(msg)
    @instance, @event, @state = instance, event, state
  end
end

# Raised when an event has no eligible move from the current state, when
# its `allows` guard rejects it, or when `can_<event>?` is asked about an
# unknown event.
class TransitionError < StatefulError; end
# Raised when a hook names a state that was never declared.
class StateNotFoundError < StatefulError; end
# Raised when firing an event that was never defined.
class EventNotFoundError < StatefulError; end
class_methods do
  # This class's state machine, built on first use and memoized.
  # NOTE(review): `column` only takes effect on the call that builds the
  # machine; every later call returns the existing machine whatever it passes.
  def machine(column: :state)
    return @machine if @machine

    @machine = Machine.new(self, state_column: column || :state)
  end

  # DSL entry point: the block is instance_eval'd on the machine, so
  # `state`, `on`, `entered`, ... can be called directly inside it.
  # NOTE(review): `column` defaults to :status here but to :state in
  # `machine` and `register_listener` - confirm which default is intended.
  def stateful(column: :status, &block)
    definition = machine(column: column)
    definition.instance_eval(&block)
  end

  # Names of the declared states, as Strings.
  def states
    machine.states
  end

  # Add a listener to this class's machine (see Machine#listener for the
  # options). `opts[:column]`, defaulting to :state, is only used when this
  # call is the one that builds the machine.
  def register_listener(listener, opts = {}, &block)
    listening_column = opts[:column] || :state
    machine(column: listening_column).listener(listener, opts, &block)
  end
end
# Sentinel key meaning "any state": used as a move source and as the key
# for hooks registered without state names.
ANY = Object.new

# Everything a hook or listener learns about one transition. Hooks are the
# `entering`, `entered` and `persisted` blocks.
Context = Struct.new(:instance, :event, :src, :dest, :args) do
  # Run the hooks registered for `dest`, then those registered for ANY;
  # each is instance_exec'd on the record and receives this context.
  # Afterwards every listener is checked, and triggered if it applies.
  def trigger(hooks, listeners, transition:)
    blocks = [hooks[dest], hooks[ANY]].compact.flatten
    blocks.each { |hook| instance.instance_exec(self, &hook) }

    listeners.each do |listener|
      next unless listener.applies?(self, transition: transition)

      listener.trigger(self, transition: transition)
    end
  end
end
# Options declared for a state through Machine#state. `name` is always a
# String; `log_first_transition_only` is true only when the state was
# declared with `log_transition_time: :first_only`.
StateProperties = Struct.new(:name, :log_first_transition_only) do
  def initialize(name, opts = {})
    super(name.to_s, opts[:log_transition_time] == :first_only)
  end
end

# One edge of an event: from `src` to `dest`. `condition`, when set, is
# called with the record and must be truthy for the move to apply.
# `no_op`, when truthy, makes firing leave the state unchanged; it is
# called with the record if it responds to #call.
Move = Struct.new(:src, :dest, :condition, :no_op)
# Events are things like `start` or `unload`. They cause the
# state machine to transition from one state to another.
# A corresponding method is created on the model for each event
# that is defined. The method is <event>!, e.g. car.start!
# Another method is also defined that is responsible for determining
# if the event is a valid transition. The predicate is can_<event>?, e.g.
# car.can_start?
class Event
  # Defines `<name>!` and `can_<name>?` on the model; both delegate to the
  # model's machine (`fire` and `transition_valid?` respectively).
  #
  # @param model [Class] model class that receives the generated methods
  # @param name [Symbol] event name
  def initialize(model, name)
    @moves = Hash.new { |h, k| h[k] = [] }
    # Define transition methods
    model.public_send :define_method, "#{name}!" do |*args|
      model.machine.fire self, name, *args
    end
    # Define transition predicates
    model.public_send :define_method, "can_#{name}?" do |*args|
      model.machine.transition_valid? self, name, *args
    end
  end

  # DSL helper for the wildcard source state, e.g. `move any => :cancelled`.
  def any
    ANY
  end

  # Get the move to take from the given current state for this instance.
  #
  # Moves declared for `current` are tried first, in declaration order,
  # followed by moves declared from `any`. A move with a condition is only
  # eligible when the condition returns truthy for `instance`; an
  # ineligible move is treated as if it did not exist.
  #
  # @param current [Symbol, String] current state
  # @param instance [Object] record that is transitioning
  # @return [Move, nil] the first eligible move, or nil when none applies
  def dest(current, instance)
    # fetch (not []) so lookups don't add empty entries to @moves.
    candidates = @moves.fetch(current.to_sym, []) + @moves.fetch(ANY, [])
    candidates.find { |move| move.condition.nil? || move.condition.call(instance) }
  end

  # Declare one or more moves, e.g.
  #   move idle: :running
  #   move %i[idle paused] => :running, if: ->(car) { car.fueled? }
  #   move any => :cancelled
  #
  # Every source => destination pair in the hash is registered. Options
  # apply to all of them:
  #   if:    callable receiving the record; the move only applies when truthy
  #   no_op: when truthy, firing through this move leaves the state unchanged
  #          (it is called with the record if it responds to #call)
  # The caller's hash is not modified.
  def move(pair)
    options = pair.dup
    condition = options.delete(:if)
    no_op = options.delete(:no_op)
    options.each do |sources, destination|
      Array(sources).each do |source|
        # dest looks states up with to_sym, so store String sources as
        # Symbols; ANY has no #to_sym and is kept as-is.
        key = source.respond_to?(:to_sym) ? source.to_sym : source
        @moves[key] << Move.new(key, destination, condition, no_op)
      end
    end
  end
end
# A subscriber to state transitions, registered through Machine#listener
# (directly, or via Machine#spy on an associated model).
class Listener
  # @param listener_name [Symbol] method on the transitioning record that
  #   returns the object to notify
  # @param options [Hash]
  #   :method     - method called on that object with the payload hash
  #   :on         - event name(s) to react to; all events when omitted
  #   :entering, :entered, :persisted - destination state(s) accepted for
  #                 that phase; when none of these keys is given, every
  #                 phase and destination matches
  #   :listening_class, :observed_class - when both are given, the record
  #                 and the notified object must be instances of them
  #   :observed_instance - stored, but not read by this class
  # @param block optional; instance_exec'd on the notified object with the
  #   payload, instead of calling :method
  def initialize(listener_name, options, &block)
    @listener_name = listener_name
    @block = block
    @method = options[:method]
    @events = Array(options[:on]).compact
    phases = Stateful::Transitions
    @transitions = options.slice(phases::ENTERING, phases::ENTERED, phases::PERSISTED)
    @listening_class = options[:listening_class]
    @observed_class = options[:observed_class]
    @observed_instance = options[:observed_instance]
  end

  # Should this listener react to `context` during phase `transition`?
  def applies?(context, transition:)
    return false unless event_applies?(context)
    return false unless transition_applies?(transition, context)

    class_types_match?(context)
  end

  # Notify the listener object with a payload describing the transition.
  def trigger(context, transition:)
    payload = {
      originator: context.instance,
      transition: transition,
      event: context.event,
      source: context.src,
      destination: context.dest
    }
    recipient = context.instance.public_send(@listener_name)
    return recipient.instance_exec(payload, &@block) if @block

    recipient.public_send(@method, payload)
  end

  private

  def transition_applies?(transition, context)
    return true if @transitions.empty?

    Array(@transitions[transition]).include?(context.dest)
  end

  def event_applies?(context)
    return true if @events.empty?

    @events.include?(context.event)
  end

  def class_types_match?(context)
    return true if !@observed_class || !@listening_class

    return false unless context.instance.is_a?(@observed_class)

    context.instance.public_send(@listener_name).is_a?(@listening_class)
  end
end
# Per-model state machine definition: the declared states, the events, the
# `entering` / `entered` / `persisted` hooks, the `allows` guards and the
# listeners. It also performs transitions (see #fire).
class Machine
  # @param model [Class] the model class this machine belongs to
  # @param opts [Hash] :state_column - attribute holding the state
  #   (defaults to :state)
  def initialize(model, opts = {})
    @state_column = opts[:state_column] || :state
    @model = model
    @state_properties = []
    @initial_state = nil
    @entered = Hash.new { |h, k| h[k] = [] }
    @entering = Hash.new { |h, k| h[k] = [] }
    # Looking up an unknown event name creates the Event (and its methods).
    @events = Hash.new { |h, k| h[k] = Event.new(@model, k) }
    @persisted = Hash.new { |h, k| h[k] = [] }
    # event name => guard block; a missing entry means "no extra criteria".
    @allows = {}
    @listeners = []
  end

  # Register a hook that runs just before the state attribute is assigned.
  # Names may be given individually or as arrays. With no names, the hook
  # runs for every destination state. Hooks are instance_exec'd on the
  # record and receive the transition Context.
  #
  # @raise [StateNotFoundError] if a name is not a declared state
  def entering(*names, &block)
    validate_states(names)
    munge(names).each { |n| @entering[n] << block }
  end

  # Like #entering, but runs after the state attribute is assigned and
  # before the record is saved.
  def entered(*names, &block)
    validate_states(names)
    munge(names).each { |n| @entered[n] << block }
  end

  # Like #entering, but runs after the record is saved, still inside the
  # transaction.
  def persisted(*names, &block)
    validate_states(names)
    munge(names).each { |n| @persisted[n] << block }
  end

  # Fire event `name` on `instance`. Inside a transaction this:
  # 1. runs the entering hooks and listeners;
  # 2. assigns the new state;
  # 3. runs the entered hooks and listeners;
  # 4. stamps the `<state>_at` / `<state>_count` attributes if present;
  # 5. calls save!;
  # 6. runs the persisted hooks and listeners.
  #
  # @return [Object, nil] the saved instance, or nil when the move is a no_op
  # @raise [EventNotFoundError] if `name` is not a defined event
  # @raise [TransitionError] if no move applies from the current state, or
  #   the event's `allows` guard fails
  def fire(instance, name, *args)
    unless @events.include?(name)
      raise EventNotFoundError.new("No [#{name}] event.", instance, name)
    end

    src = current_state(instance)
    move = @events[name].dest(src, instance)
    unless move
      raise TransitionError.new("Cannot [#{name}] while in [#{src}]: #{instance}", instance, name, src)
    end

    # Make sure the additional transition criteria were met
    unless allowed?(instance, name)
      message = "Cannot [#{name}] while transition criteria not met: #{instance}"
      raise TransitionError.new(message, instance, name, src)
    end

    if move.no_op
      move.no_op.call(instance) if move.no_op.respond_to?(:call)
      return
    end

    ctx = Context.new(instance, name, src.to_sym, move.dest, args)
    # Use the model's own connection: in multi-database apps,
    # ActiveRecord::Base.transaction would wrap the primary database
    # instead of the one this record is saved to.
    instance.class.transaction do
      ctx.trigger(@entering, @listeners, transition: Stateful::Transitions::ENTERING)
      instance.public_send("#{@state_column}=", ctx.dest.to_s)
      ctx.trigger(@entered, @listeners, transition: Stateful::Transitions::ENTERED)
      log_transition(ctx)
      instance.save!
      ctx.trigger(@persisted, @listeners, transition: Stateful::Transitions::PERSISTED)
    end
    instance
  end

  # Whether `name` has a move from the current state whose `allows` guard
  # passes. Runs no hooks and saves nothing.
  #
  # @return [Boolean]
  # @raise [TransitionError] if `name` is not a defined event
  def transition_valid?(instance, name, *_args)
    raise TransitionError, "No [#{name}] event." unless @events.include?(name)

    move = @events[name].dest(current_state(instance), instance)
    !move.nil? && allowed?(instance, name)
  end

  # @return [StateProperties, nil] the properties declared for state `name`
  def state_property(name)
    @state_properties.detect { |prop| prop.name == name.to_s }
  end

  # Define moves for one or more events. The block is instance_eval'd on
  # each Event, e.g. `on(:start) { move idle: :running }`.
  def on(*names, &block)
    names.flatten.each { |n| @events[n].instance_eval(&block) }
  end

  # Additional state change criteria. The block is instance_exec'd on the
  # record and must return truthy, or the event fails. A later call for the
  # same event replaces the earlier block.
  def allows(*names, &block)
    names.flatten.each { |n| @allows[n] = block }
  end

  # We can spy on state transitions from associations if :inverse_of is
  # specified in both directions. The listener is registered on the
  # associated class's machine and is invoked through the inverse association.
  def spy(name, opts = {}, &block)
    reflection = @model.reflect_on_all_associations.detect { |a| a.name == name }
    raise ArgumentError, "Could not find association #{name} on #{@model.name}" unless reflection

    inverse = reflection.inverse_of
    listener_name = inverse&.name
    if listener_name.nil?
      raise ArgumentError, 'Expected `:inverse_of` to be defined on both parent and child '\
                           "associations when spying on #{name} with options [#{opts.inspect}]"
    end

    opts = opts.merge(listening_class: @model, observed_class: inverse.active_record)
    inverse.active_record.register_listener(listener_name, opts, &block)
  end

  # Declare a state.
  #
  # Defines on the model:
  # - the scopes `<name>` and `not_<name>`, skipped when the class already
  #   responds to `<name>` (e.g. a state called `new`), so existing class
  #   methods are never shadowed;
  # - the predicate `<name>?`.
  # Redeclaring a state keeps its first set of options.
  #
  # @param name [Symbol, String]
  # @param opts [Hash]
  #   :initial - make this the initial state (only one allowed);
  #   :log_transition_time - :first_only stamps `<name>_at` only on the
  #     first transition into the state
  # @raise [ArgumentError] if a second initial state is declared
  def state(name, opts = {})
    name = name.to_s
    @state_properties << StateProperties.new(name, opts) unless state_property(name)
    # Need to pull this into a local var for block access
    column = @state_column

    if opts[:initial]
      raise ArgumentError, "Initial state of [#{@initial_state}] already specified" if @initial_state

      @initial_state = name
      define_initial_state_default(name, column)
    end

    clashes_with_class_method = @model.respond_to?(name)
    define_state_scopes(name, column) unless clashes_with_class_method
    # The predicate lives on instances, so a class-level clash (e.g. Class#new
    # for a state named `new`) must not stop it from being defined. The
    # original behaviour of skipping it is kept only when a predicate of that
    # name already exists (e.g. ActiveRecord's `readonly?`).
    return if clashes_with_class_method && @model.method_defined?("#{name}?")

    define_state_predicate(name, column)
  end

  # @return [Array<String>] declared state names, in declaration order
  def states
    @state_properties.map(&:name)
  end

  # Add a listener to state transitions:
  #
  # listener_name  Required method on the record that returns the listener instance
  # method:        Optional method name on the listener to call when a transition occurs
  # on: [events]   Optional events to filter by, i.e. invoke only if the fired event is
  #                in the on: arg. Can be either an array of events or a single event.
  # entering: [states]  Optional per-phase filters on the destination state. When none
  # entered: [states]   of these keys is given, every phase matches; when any is given,
  # persisted: [states] a phase without a key never matches.
  # block          Optional block to invoke on transition rather than calling a method
  #                on the listener. Block is invoked via an instance_exec on the listener.
  def listener(listener_name, options = {}, &block)
    @listeners << Listener.new(listener_name, options, &block)
  end

  private

  # Current state of `instance`, falling back to the declared initial state.
  def current_state(instance)
    instance.public_send(@state_column) || @initial_state
  end

  # True when event `name` has no `allows` guard or its guard returns truthy.
  def allowed?(instance, name)
    guard = @allows[name]
    return true if guard.nil?

    instance.instance_exec(&guard) ? true : false
  end

  # Assign the initial state to records initialized with a blank state column.
  def define_initial_state_default(name, column)
    return unless @model.respond_to?(:after_initialize)

    @model.after_initialize do |instance|
      # Records loaded through a narrowed `select` may lack the column;
      # reading it would raise ActiveModel::MissingAttributeError.
      next unless instance.has_attribute?(column)

      instance.public_send(:"#{column}=", name) unless instance.public_send(column)
    end
  end

  # Define the `<name>` and `not_<name>` scopes.
  def define_state_scopes(name, column)
    @model.scope name, -> { where(column => name) }
    # where.not table-qualifies the column, unlike a raw SQL fragment, so the
    # scope stays unambiguous when joined with tables sharing the column name.
    @model.scope "not_#{name}", -> { where.not(column => name) }
  end

  # Define `<name>?`, which treats a blank state column as the initial state.
  def define_state_predicate(name, column)
    initial = @initial_state
    @model.public_send(:define_method, "#{name}?") do
      name == (send(column)&.to_s || initial)
    end
  end

  # Stamp `<dest>_at` (unless the state logs only its first transition and
  # the stamp is already set), and increment `<dest>_count`, when the record
  # has those attributes.
  def log_transition(context)
    record = context.instance
    dest = context.dest

    at_attr = "#{dest}_at"
    if record.has_attribute?(at_attr)
      # A destination never declared via #state has no properties; treat it
      # as "log every transition" instead of raising NoMethodError.
      first_only = state_property(dest)&.log_first_transition_only
      record.public_send("#{at_attr}=", Time.current) if record.public_send(at_attr).nil? || !first_only
    end

    count_attr = "#{dest}_count"
    return unless record.has_attribute?(count_attr)

    record.public_send("#{count_attr}=", (record.public_send(count_attr) || 0) + 1)
  end

  # Flatten state names and map an empty list to [ANY], so that a hook
  # declared without names runs for every destination state.
  def munge(names)
    flat = names.flatten
    flat.empty? ? [ANY] : flat
  end

  # Raise unless every given name (arrays are flattened, as in #munge) is a
  # declared state. Names are compared as Symbols.
  def validate_states(states)
    known = @state_properties.map { |prop| prop.name.to_sym }
    states.flatten.each do |state|
      raise StateNotFoundError, "Unknown state: '#{state}'" unless known.include?(state)
    end
  end
end
end
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.