-
-
Save ecatmur/4683098 to your computer and use it in GitHub Desktop.
State monad with n-ary lift. Type construction is collapsed to allow type inference.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <functional>
#include <iostream>
#include <type_traits>
#include <utility>
namespace hask {

// Unit type: the value yielded by computations that are run only for their
// effect on the state (see put).
struct Unit {};

// A stateful computation. runState maps an initial state of type S to the
// pair (final state, result value of type V). Note the pair order is
// (state, value), not (value, state).
template<typename S, typename V> struct State {
    std::function<std::pair<S, V>(S)> runState;
    using state_type = S;
    using value_type = V;
};

// return / pure: yields v and leaves the state unchanged.
template<typename S, typename V> State<S, V> ret(const V& v) {
    return {([=](S s) { return std::make_pair(s, v); })};
}

// Yields the current state as the value, leaving the state unchanged.
template<typename S> State<S, S> get() {
    return {([](S s) { return std::make_pair(s, s); })};
}

// Replaces the state with s and yields Unit.
template<typename S> State<S, Unit> put(const S &s) {
    return {([=](S) { return std::make_pair(s, Unit{}); })};
}

// Monadic bind (>>=): runs `a`, passes its value to `f`, then runs the State
// that `f` returns starting from the state `a` left behind.
//   f : callable taking an A and returning State<S, R>
// R is the value_type of f's (decayed) return type. The result type is taken
// from a plain call expression rather than std::result_of (deprecated in
// C++17, removed in C++20); the argument is an rvalue A, exactly as f is
// invoked below.
template<typename S, typename A, typename F,
         typename R = typename std::decay<
             decltype(std::declval<F>()(std::declval<A>()))>::type::value_type>
State<S, R> bind(State<S, A> a, F &&f) {
    return {([=](S init) -> std::pair<S, R> {
        auto p = a.runState(init);
        // p is a local we are finished with: move its parts out so that f may
        // take its argument by value or by rvalue reference, and so the
        // intermediate state is not copied needlessly.
        return f(std::move(p.second)).runState(std::move(p.first));
    })};
}

// Sequencing (>>): runs `a`, discards its value, then runs `r`.
template<typename S, typename A, typename R>
State<S, R> seq(State<S, A> a, State<S, R> r) {
    return bind<S, A>(a, [=](A) { return r; });
}

// Lifting a 0-ary function gives ret: f is called right away (when lift is
// called) and its result becomes the value; the state is left unchanged.
// R is decayed to match the V that ret deduces from f's result, so callables
// returning references or const class prvalues are accepted.
template<typename S, typename F,
         typename R = typename std::decay<decltype(std::declval<F>()())>::type>
State<S, R> lift(F &&f) {
    return ret<S>(f());
}

// Lifts an (n+1)-ary function into State: runs sa, then each of sbs from left
// to right, threading the state through, and yields f applied to their
// values. Implemented by recursion on arity: bind the first argument, then
// lift the partially applied function over the remaining computations.
template<typename S, typename F, typename A, typename... Bs,
         typename R = typename std::decay<
             decltype(std::declval<F>()(std::declval<A>(),
                                        std::declval<Bs>()...))>::type>
State<S, R> lift(F &&f, State<S, A> sa, State<S, Bs>... sbs) {
    return bind<S, A>(sa, [=](A a) {
        return lift<S>([=](Bs... bs) { return f(a, bs...); }, sbs...);
    });
}

}  // namespace hask
int main() { | |
// auto incrS = hask::bind( | |
// hask::get<int>(), [](int x) { return hask::ret<int>(x + 1); }); | |
// auto res = incrS.runState(7); | |
// std::cout << res.first << ',' << res.second << '\n'; | |
auto r2 = hask::lift<int>([](int x, int y) { return x > y; }, | |
hask::get<int>(), hask::ret<int>(100)).runState(101); | |
std::cout << r2.first << ',' << std::boolalpha << r2.second << '\n'; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment