Skip to content

Instantly share code, notes, and snippets.

@ecatmur
Forked from lcapaldo/gist:2065956
Last active December 12, 2015 00:19
Show Gist options
  • Save ecatmur/4683098 to your computer and use it in GitHub Desktop.
Save ecatmur/4683098 to your computer and use it in GitHub Desktop.
State monad with n-ary lift. Type construction is collapsed to allow template argument inference.
#include <functional>
#include <utility>
#include <iostream>
namespace hask {

/// Value-less result type; it is what `put` yields.
struct Unit {};

namespace detail {
/// Result type of invoking an rvalue `F` with arguments of types `Args...`.
///
/// Replaces `std::result_of<F(Args...)>::type`, which is deprecated in
/// C++17 and removed in C++20. This form needs only C++11 and `<utility>`.
template<typename F, typename... Args>
using call_result_t = decltype(std::declval<F>()(std::declval<Args>()...));
}  // namespace detail

/// A stateful computation: `runState` maps an initial state of type S to a
/// pair of (final state, result value of type V).
///
/// Kept as an aggregate whose only data member is `runState`, so the
/// combinators below can build one directly from a lambda with `{...}`.
template<typename S, typename V> struct State {
  std::function<std::pair<S, V>(S)> runState;
  using state_type = S;
  using value_type = V;
};

/// Monadic return: yields `v` and leaves the state untouched.
/// S cannot be deduced from the argument, so call it as `ret<S>(v)`.
template<typename S, typename V> State<S, V> ret(const V& v) {
  return {([=](S s) { return std::make_pair(s, v); })};
}

/// Yields the current state as the result, leaving the state untouched.
template<typename S> State<S, S> get() {
  return {([](S s) { return std::make_pair(s, s); })};
}

/// Replaces the state with `s` (the incoming state is ignored) and yields Unit.
template<typename S> State<S, Unit> put(const S &s) {
  return {([=](S) { return std::make_pair(s, Unit{}); })};
}

/// Monadic bind (>>=): runs `a`, passes its result to `f`, then runs the
/// computation returned by `f` on the state that `a` produced.
///
/// @param a  the first computation.
/// @param f  callable taking an A and returning some State<S, R>; a copy of
///           it is stored inside the returned computation.
/// @return   a computation yielding the R produced by `f`'s computation.
template<typename S, typename A, typename F /* State<S, R> (A) */,
         typename R = typename detail::call_result_t<F, A>::value_type>
State<S, R> bind(State<S, A> a, F &&f) {
  return {([=](S init) -> std::pair<S, R> {
    auto p = a.runState(std::move(init));
    // The intermediate state is not used again, so it is moved. The value is
    // passed as an lvalue because `f` may take its argument by reference.
    return f(p.second).runState(std::move(p.first));
  })};
}

/// Sequencing (>>): runs `a`, discards its result, then runs `r`.
template<typename S, typename A, typename R>
State<S, R> seq(State<S, A> a, State<S, R> r) {
  return bind<S, A>(a, [=](A) { return r; });
}

/// Lifts a nullary function into State; this is equivalent to `ret<S>(f())`.
/// Note that `f` is called here, when `lift` is called, not when the
/// resulting computation is run.
template<typename S, typename F,
         typename R = detail::call_result_t<F>>
State<S, R> lift(F &&f) {
  return ret<S>(f());
}

/// Lifts an (n+1)-ary function `f` over n+1 stateful computations.
///
/// The computations run left to right, each on the state left by the
/// previous one. `f` is applied to their results, and the final state is
/// the one left by the last computation. Implemented by binding the first
/// computation and recursing with `f` partially applied to its result; the
/// recursion bottoms out in the nullary overload above.
///
/// There is deliberately no separate generic `lift(F&&, State<S, As>...)`
/// declaration. An undefined one would compete with the nullary overload
/// for the base-case call, which some compilers treat as ambiguous.
template<typename S, typename F, typename A, typename... Bs,
         typename R = detail::call_result_t<F, A, Bs...>>
State<S, R> lift(F &&f, State<S, A> sa, State<S, Bs>... sbs) {
  return bind<S, A>(sa, [=](A a) {
    return lift<S>([=](Bs... bs) { return f(a, bs...); }, sbs...);
  });
}

}  // namespace hask
int main() {
  // Another usage example (bind): increment the value read from the state.
  //   auto incrS = hask::bind(
  //       hask::get<int>(), [](int x) { return hask::ret<int>(x + 1); });
  //   auto res = incrS.runState(7);
  //   std::cout << res.first << ',' << res.second << '\n';

  // Lift the binary predicate `x > y` over two stateful computations: the
  // first yields the current state, the second yields the constant 100.
  auto is_greater = [](int x, int y) { return x > y; };
  auto computation =
      hask::lift<int>(is_greater, hask::get<int>(), hask::ret<int>(100));

  // Run it from initial state 101 and print "<final state>,<result>".
  const auto result = computation.runState(101);
  std::cout << result.first << ',' << std::boolalpha << result.second << '\n';
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment