Skip to content

Instantly share code, notes, and snippets.

@rioki
Created April 6, 2024 15:58
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rioki/10d009f788408c803d2a32b88188c8d6 to your computer and use it in GitHub Desktop.
Save rioki/10d009f788408c803d2a32b88188c8d6 to your computer and use it in GitHub Desktop.
State Machine
// State Machine
// Copyright 2022 Sean Farrell <sean.farrell@rioki.org>
//
// This program is free software. It comes without any warranty, to
// the extent permitted by applicable law. You can redistribute it
// and/or modify it under the terms of the Do What The Fuck You Want
// To Public License, Version 2, as published by Sam Hocevar. See
// http://www.wtfpl.net/ for more details.
#pragma once
#include <cassert>
#include <optional>
#include <functional>
#include <map>
namespace sm
{
/// A minimal enum-driven finite state machine.
///
/// Each state may have at most one callback per event:
///  - enter: run by change_state() right after the machine switches to the state,
///  - tick:  run by tick() while the machine is in the state,
///  - exit:  run by change_state() right before the machine leaves the state.
///
/// Transitions are either immediate (change_state()) or deferred until the
/// next tick() (queue_state()). Changing to the current state still runs its
/// exit and enter callbacks.
///
/// @tparam StateEnum Type naming the states; it is used as a std::map key.
///         A default-constructed machine starts in static_cast<StateEnum>(0).
///
/// NOTE: change_state() and tick() are noexcept, so a callback that throws
/// results in std::terminate().
template <typename StateEnum>
class StateMachine
{
public:
    /// Creates a machine in the state whose underlying value is 0.
    StateMachine() noexcept = default;

    /// Creates a machine in @p initial_state. No enter callback is run.
    StateMachine(StateEnum initial_state) noexcept
    : current_state(initial_state) {}

    StateMachine(const StateMachine&) = delete;
    StateMachine& operator = (const StateMachine&) = delete;

    /// @return The state the machine is currently in.
    [[nodiscard]] StateEnum get_state() const noexcept
    {
        return current_state;
    }

    /// Immediately transitions to @p new_state: runs the exit callback of the
    /// current state (if any), switches, then runs the enter callback of
    /// @p new_state (if any).
    void change_state(StateEnum new_state) noexcept
    {
        run_callback(exit_functions, current_state);
        current_state = new_state;
        run_callback(enter_functions, current_state);
    }

    /// Schedules a transition to @p new_state that is performed at the start
    /// of the next tick(). A later call before that tick replaces it.
    void queue_state(StateEnum new_state) noexcept
    {
        next_state = new_state;
    }

    /// Performs the queued transition (if any), then runs the tick callback
    /// of the resulting current state (if any).
    void tick() noexcept
    {
        if (next_state)
        {
            // Clear the pending transition *before* performing it, so that a
            // state queued from an exit/enter callback is kept for the next tick
            // instead of being wiped out afterwards.
            const StateEnum target = *next_state;
            next_state.reset();
            change_state(target);
        }
        run_callback(tick_functions, current_state);
    }

    /// Registers @p func to run when the machine enters @p state.
    /// At most one enter callback per state; @p func must be non-empty.
    void on_enter(StateEnum state, const std::function<void ()>& func)
    {
        register_callback(enter_functions, state, func);
    }

    /// Registers @p func to run on every tick() while in @p state.
    /// At most one tick callback per state; @p func must be non-empty.
    void on_tick(StateEnum state, const std::function<void ()>& func)
    {
        register_callback(tick_functions, state, func);
    }

    /// Registers @p func to run when the machine leaves @p state.
    /// At most one exit callback per state; @p func must be non-empty.
    void on_exit(StateEnum state, const std::function<void ()>& func)
    {
        register_callback(exit_functions, state, func);
    }

private:
    using CallbackMap = std::map<StateEnum, std::function<void ()>>;

    // Invokes the callback registered for `state` in `callbacks`, if there is one.
    static void run_callback(const CallbackMap& callbacks, StateEnum state)
    {
        const auto it = callbacks.find(state);
        if (it != callbacks.end())
        {
            assert(it->second);
            it->second();
        }
    }

    // Stores `func` for `state`; duplicate registration is a programming error
    // (asserted in debug builds, last registration wins otherwise).
    static void register_callback(CallbackMap& callbacks, StateEnum state,
                                  const std::function<void ()>& func)
    {
        assert(func);
        assert(callbacks.find(state) == callbacks.end());
        callbacks.insert_or_assign(state, func);
    }

    StateEnum current_state = static_cast<StateEnum>(0);
    std::optional<StateEnum> next_state = std::nullopt;
    CallbackMap tick_functions;
    CallbackMap enter_functions;
    CallbackMap exit_functions;
};
}
// State Machine
// Copyright 2022 Sean Farrell <sean.farrell@rioki.org>
//
// This program is free software. It comes without any warranty, to
// the extent permitted by applicable law. You can redistribute it
// and/or modify it under the terms of the Do What The Fuck You Want
// To Public License, Version 2, as published by Sam Hocevar. See
// http://www.wtfpl.net/ for more details.
#include <gtest/gtest.h>
#include "StateMachine.h"
// States driven by the StateMachine tests below. INIT is pinned to 0 because a
// default-constructed sm::StateMachine starts in static_cast<State>(0), which
// the create_default_zero test relies on.
enum class State
{
    INIT = 0,
    STATE1,
    STATE2,
    END
};
// Constructing with an explicit state makes it the current state.
TEST(StateMachine, create)
{
    const State expected = State::INIT;
    sm::StateMachine<State> machine{expected};
    EXPECT_EQ(expected, machine.get_state());
}
// A default-constructed machine starts in the enumerator with value 0,
// which for State is INIT.
TEST(StateMachine, create_default_zero)
{
    sm::StateMachine<State> machine{};
    EXPECT_EQ(State::INIT, machine.get_state());
}
// change_state() switches the current state immediately, without a tick().
TEST(StateMachine, change_state)
{
    sm::StateMachine<State> machine{State::INIT};
    EXPECT_EQ(State::INIT, machine.get_state());

    const State sequence[] = {State::STATE1, State::STATE2, State::END};
    for (const State target : sequence)
    {
        machine.change_state(target);
        EXPECT_EQ(target, machine.get_state());
    }
}
// queue_state() defers the transition until the next tick().
TEST(StateMachine, queue_state)
{
    sm::StateMachine<State> machine{State::INIT};
    EXPECT_EQ(State::INIT, machine.get_state());

    State current = State::INIT;
    const State sequence[] = {State::STATE1, State::STATE2, State::END};
    for (const State target : sequence)
    {
        machine.queue_state(target);
        // Queuing alone leaves the state untouched...
        EXPECT_EQ(current, machine.get_state());
        // ...the switch happens on the following tick().
        machine.tick();
        EXPECT_EQ(target, machine.get_state());
        current = target;
    }
}
// Tick callbacks run once per tick() while the machine is in their state,
// with transitions made immediately via change_state().
TEST(StateMachine, processing_intructions)
{
    auto sm = sm::StateMachine<State>(State::INIT);
    auto state1_tick = 0u;
    sm.on_tick(State::STATE1, [&] () {
        state1_tick++;
    });
    auto state2_tick = 0u;
    sm.on_tick(State::STATE2, [&] () {
        state2_tick++;
    });
    sm.tick();
    sm.change_state(State::STATE1);
    sm.tick();
    sm.tick();
    sm.change_state(State::STATE2);
    sm.tick();
    sm.tick();
    sm.tick();
    sm.change_state(State::END);
    sm.tick();
    EXPECT_EQ(State::END, sm.get_state());
    // Unsigned literals: the counters are unsigned, avoid -Wsign-compare in gtest.
    EXPECT_EQ(2u, state1_tick);
    EXPECT_EQ(3u, state2_tick);
}
// Same as processing_intructions, but transitions are queued and take effect
// at the start of the next tick(), before that tick's callback runs.
TEST(StateMachine, processing_intructions_with_queue_state)
{
    auto sm = sm::StateMachine<State>(State::INIT);
    auto state1_tick = 0u;
    sm.on_tick(State::STATE1, [&] () {
        state1_tick++;
    });
    auto state2_tick = 0u;
    sm.on_tick(State::STATE2, [&] () {
        state2_tick++;
    });
    sm.tick();
    sm.queue_state(State::STATE1);
    sm.tick();
    sm.tick();
    sm.queue_state(State::STATE2);
    sm.tick();
    sm.tick();
    sm.tick();
    sm.queue_state(State::END);
    sm.tick();
    EXPECT_EQ(State::END, sm.get_state());
    // Unsigned literals: the counters are unsigned, avoid -Wsign-compare in gtest.
    EXPECT_EQ(2u, state1_tick);
    EXPECT_EQ(3u, state2_tick);
}
// Enter/exit callbacks run exactly once per transition into/out of their
// state, regardless of how many ticks happen in between.
TEST(StateMachine, pre_and_post_intructions)
{
    auto sm = sm::StateMachine<State>(State::INIT);
    auto state1_enter = 0u;
    sm.on_enter(State::STATE1, [&] () {
        state1_enter++;
    });
    auto state1_exit = 0u;
    sm.on_exit(State::STATE1, [&] () {
        state1_exit++;
    });
    auto state2_enter = 0u;
    sm.on_enter(State::STATE2, [&] () {
        state2_enter++;
    });
    auto state2_exit = 0u;
    sm.on_exit(State::STATE2, [&] () {
        state2_exit++;
    });
    sm.tick();
    sm.change_state(State::STATE1);
    sm.tick();
    sm.tick();
    sm.change_state(State::STATE2);
    sm.tick();
    sm.tick();
    sm.tick();
    sm.change_state(State::END);
    sm.tick();
    EXPECT_EQ(State::END, sm.get_state());
    EXPECT_EQ(1u, state1_enter);
    EXPECT_EQ(1u, state1_exit);
    // STATE2 callbacks were previously never checked (state1 asserts were duplicated).
    EXPECT_EQ(1u, state2_enter);
    EXPECT_EQ(1u, state2_exit);
}
// Same as pre_and_post_intructions, but with transitions queued via
// queue_state() and applied on the next tick().
TEST(StateMachine, pre_and_post_intructions_queue_state)
{
    auto sm = sm::StateMachine<State>(State::INIT);
    auto state1_enter = 0u;
    sm.on_enter(State::STATE1, [&] () {
        state1_enter++;
    });
    auto state1_exit = 0u;
    sm.on_exit(State::STATE1, [&] () {
        state1_exit++;
    });
    auto state2_enter = 0u;
    sm.on_enter(State::STATE2, [&] () {
        state2_enter++;
    });
    auto state2_exit = 0u;
    sm.on_exit(State::STATE2, [&] () {
        state2_exit++;
    });
    sm.tick();
    sm.queue_state(State::STATE1);
    sm.tick();
    sm.tick();
    sm.queue_state(State::STATE2);
    sm.tick();
    sm.tick();
    sm.tick();
    sm.queue_state(State::END);
    sm.tick();
    EXPECT_EQ(State::END, sm.get_state());
    EXPECT_EQ(1u, state1_enter);
    EXPECT_EQ(1u, state1_exit);
    // STATE2 callbacks were previously never checked (state1 asserts were duplicated).
    EXPECT_EQ(1u, state2_enter);
    EXPECT_EQ(1u, state2_exit);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment