Skip to content

Instantly share code, notes, and snippets.

@splinterofchaos
Created October 27, 2012 17:56
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save splinterofchaos/3965514 to your computer and use it in GitHub Desktop.
Save splinterofchaos/3965514 to your computer and use it in GitHub Desktop.
Monad
#include <algorithm>
#include <cmath>
#include <functional>
#include <iostream>
#include <iterator>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>
// Category tags: Cat<T> (below) maps a type to one of these tags, and the
// Functor/Monad specializations are selected by that tag.
struct sequence_tag {};
struct pointer_tag {};
// Fallback probe: takes its argument through `...`, the worst-ranked
// conversion, so it is chosen only when neither probe below is viable.
// It yields the explicitly supplied type itself (Category<T>::type == T).
template< class X >
X category( ... );
// Sequence probe: viable when std::begin(s) is well-formed.
template< class S >
auto category( const S& s ) -> decltype( std::begin(s), sequence_tag() );
// Pointer probe: viable when the value can be dereferenced and compared with
// nullptr. NOTE(review): a type passing both probes would presumably make the
// call ambiguous.
template< class Ptr >
auto category( const Ptr& p ) -> decltype( *p, p==nullptr, pointer_tag() );
// Category<T>::type is the tag (or T itself) picked by the probes above.
template< class T > struct Category {
using type = decltype( category<T>(std::declval<T>()) );
};
// Function references bypass the probes and map to themselves.
// NOTE(review): presumably because a function reference would otherwise pass
// the pointer probe (`*p` and `p==nullptr` are valid on functions) - confirm.
template< class R, class ... X > struct Category< R(&)(X...) > {
using type = R(&)(X...);
};
// Shorthand: Cat<T> is Category<T>::type.
template< class T >
using Cat = typename Category<T>::type;
// Primary Functor template; each category (sequence_tag, pointer_tag, or a
// type Cat<> maps to itself) provides a specialization with a static fmap.
template< class... > struct Functor;

/*
 * fmap(f, fx): apply f "inside" fx.
 * The implementation is chosen by Functor< Cat<FX> >, i.e. by the category
 * tag of the (possibly reference-qualified) type of fx.
 */
template< class Fn, class Box, class Impl = Functor< Cat<Box> > >
auto fmap( Fn&& f, Box&& fx )
    -> decltype( Impl::fmap( std::declval<Fn>(), std::declval<Box>() ) )
{
    return Impl::fmap( std::forward<Fn>(f), std::forward<Box>(fx) );
}
/*
 * Function composition: Composition<F,G>{ f, g }(x) == f(g(x)).
 * Kept an aggregate so it can be brace-initialized from the two callables.
 */
template< class F, class G >
struct Composition {
    F f;
    G g;

    // Non-const call: also supports callables whose operator() is non-const
    // (e.g. `mutable` lambdas).
    template< class X >
    auto operator () ( X&& x ) -> decltype( f(g(std::declval<X>())) ) {
        return f(g(std::forward<X>(x)));
    }

    // Const call: lets a composition be invoked through a const reference.
    // It drops out of overload resolution (SFINAE on the return type) when
    // f or g cannot be called as const.
    template< class X >
    auto operator () ( X&& x ) const -> decltype( f(g(std::declval<X>())) ) {
        return f(g(std::forward<X>(x)));
    }
};
// General case: composition.
// fmap over a function g is composition: fmap(f, g) is a callable computing
// f(g(x)). Used for any type that Cat<> maps to itself.
template< class Function > struct Functor<Function> {
    template< class F, class G, class C = Composition<F,G> >
    static C fmap( F f, G g ) {
        // Composition is an aggregate: it must be brace-initialized
        // (parenthesized aggregate initialization is only valid from C++20),
        // and the result has to be returned from this non-void function.
        return C{ std::move(f), std::move(g) };
    }
};
// Sequence functor: fmap(f, s) builds a new container of the same template,
// holding f applied to each element of s, in order.
template<> struct Functor< sequence_tag > {
    template< class F, template<class...>class S, class X,
              class R = typename std::result_of<F(X)>::type >
    static S<R> fmap( F&& f, const S<X>& s ) {
        S<R> mapped;
        mapped.reserve( s.size() );   // result has exactly s.size() elements
        auto sink = std::back_inserter( mapped );
        std::transform( std::begin(s), std::end(s), sink, std::forward<F>(f) );
        return mapped;
    }
};
// Pointer functor: a null pointer maps to null; otherwise the result owns a
// newly allocated R initialized with f(*p).
template<> struct Functor< pointer_tag > {
    template< class F, template<class...>class Ptr, class X,
              class R = typename std::result_of<F(X)>::type >
    static Ptr<R> fmap( F&& f, const Ptr<X>& p )
    {
        if( p == nullptr )
            return nullptr;
        return Ptr<R>( new R( std::forward<F>(f)(*p) ) );
    }
};
// Primary Monad template; specialized per category tag.
template< class ... > struct Monad;

/*
 * mbind(f, m, n...): monadic bind (Haskell's >>=, with the function first).
 * Dispatches to Monad< Cat<M> >::mbind. Any extra monads n... are forwarded
 * to the multi-argument binds the specializations may provide.
 */
template< class Fn, class M, class ...Rest, class Impl = Monad<Cat<M>> >
auto mbind( Fn&& f, M&& m, Rest&& ...n )
    -> decltype( Impl::mbind( std::declval<Fn>(), std::declval<M>(),
                              std::declval<Rest>()... ) )
{
    return Impl::mbind( std::forward<Fn>(f), std::forward<M>(m),
                        std::forward<Rest>(n)... );
}
/*
 * mdo(ma, mb): monadic sequencing, Haskell's `ma >> mb`. The value held by
 * ma is discarded. If ma is a failure, so is the result; otherwise the result
 * is built from mb. Both arguments are monads; dispatch uses the category of
 * the second one.
 */
template< class MA, class MB, class Mo = Monad<Cat<MB>> >
auto mdo( MA&& ma, MB&& mb )
    -> decltype( Mo::mdo( std::declval<MA>(), std::declval<MB>() ) )
{
    return Mo::mdo( std::forward<MA>(ma), std::forward<MB>(mb) );
}
/*
 * mreturn: lift a plain value into a monad (Haskell's `return`).
 *
 *   mreturn<std::vector<int>>(1)  -- the full monad type is given
 *   mreturn<std::unique_ptr>(1)   -- only the template is given
 *
 * The first template argument must always be explicit.
 */
template< class M, class X, class Impl = Monad<Cat<M>> >
M mreturn( X&& x ) {
    return Impl::template mreturn<M>( std::forward<X>(x) );
}

// Variant naming only the monad template; the result type is M<X>, with X
// deduced from the argument, which is taken by const reference.
template< template<class...>class M, class X, class Impl = Monad<Cat<M<X>>> >
M<X> mreturn( const X& x ) {
    using Result = M<X>;
    return Impl::template mreturn<Result>( x );
}

/*
 * mfail<M>(): the monad's failure value (null pointer, empty sequence).
 * The monad type must be given explicitly.
 */
template< class M, class Impl = Monad<Cat<M>> >
M mfail() {
    return Impl::template mfail<M>();
}
/*
 * Pointer ("Maybe") monad: a null pointer is the failure value; a non-null
 * pointer holds one value.
 */
template< > struct Monad< pointer_tag > {
    // p >>= f : null stays null, otherwise f(*p). f returns the result monad R.
    template< class F, template<class...>class P, class X,
              class R = typename std::result_of<F(X)>::type >
    static R mbind( F&& f, const P<X>& p ) {
        if( !p )
            return nullptr;
        return std::forward<F>(f)( *p );
    }

    // Two-argument bind: f(*p, *q) when both are non-null, otherwise null.
    template< class F, template<class...>class P,
              class X, class Y,
              class R = typename std::result_of<F(X,Y)>::type >
    static R mbind( F&& f, const P<X>& p, const P<Y>& q ) {
        if( !p || !q )
            return nullptr;
        return std::forward<F>(f)( *p, *q );
    }

    // mx >> my : null if either is null, otherwise a new M<Y> constructed
    // from *my (the input my itself is not copied or moved).
    template< template< class... > class M, class X, class Y >
    static M<Y> mdo( const M<X>& mx, const M<Y>& my ) {
        if( !mx || !my )
            return nullptr;
        return mreturn< M<Y> >( *my );
    }

    // Allocate an M::element_type initialized from x and wrap it in an M.
    template< class M, class X >
    static M mreturn( X&& x ) {
        typedef typename M::element_type Elem;
        return M( new Elem( std::forward<X>(x) ) );
    }

    // Failure is the null pointer.
    template< class M >
    static M mfail() { return nullptr; }
};
/*
 * Sequence (list) monad: a container models a computation with any number
 * of results; the empty container is the failure value.
 */
template< > struct Monad< sequence_tag > {
    // xs >>= f : apply f to every element and concatenate the returned
    // sequences, in order.
    template< class F, template<class...>class S, class X,
              class R = typename std::result_of<F(X)>::type >
    static R mbind( F&& f, const S<X>& xs ) {
        R r;
        for( const X& x : xs ) {
            // f runs once per element, so it is called as an lvalue rather
            // than std::forward'ed (an rvalue f could be consumed by the
            // first call).
            auto ys = f( x );
            std::move( std::begin(ys), std::end(ys), std::back_inserter(r) );
        }
        return r;
    }

    // Two-argument bind: f(x, y) for every pair of the Cartesian product
    // (xs outer, ys inner), results concatenated in that order.
    template< class F, template<class...>class S,
              class X, class Y,
              class R = typename std::result_of<F(X,Y)>::type >
    static R mbind( F&& f, const S<X>& xs, const S<Y>& ys ) {
        R r;
        for( const X& x : xs ) {
            for( const Y& y : ys ) {
                auto zs = f( x, y );
                std::move( std::begin(zs), std::end(zs),
                           std::back_inserter(r) );
            }
        }
        return r;
    }

    // mx >> my == mx >>= (\_ -> my): one copy of my for every element of mx,
    // concatenated. An empty mx gives an empty result; a single-element mx
    // (such as guard's output) gives my unchanged.
    template< template< class... > class S, class X, class Y >
    static S<Y> mdo( const S<X>& mx, const S<Y>& my ) {
        S<Y> r;
        for( auto it = std::begin(mx); it != std::end(mx); ++it )
            std::copy( std::begin(my), std::end(my), std::back_inserter(r) );
        return r;
    }

    template< class S, class X >
    static S mreturn( X&& x ) {
        return S{ std::forward<X>(x) }; // Construct an S of one element.
    }

    // Failure is the empty sequence.
    template< class S >
    static S mfail() { return S{}; }
};
/*
 * addM(a, b): x + y for every int x held by a and every int y held by b,
 * using nested mbind + mreturn
 * (Haskell: do { x <- a; y <- b; return (x + y) }).
 */
template< class M >
M addM( const M& a, const M& b ) {
    auto addToEachOfB = [&b]( int x ) {
        auto plusX = [x]( int y ) { return mreturn<M>( x + y ); };
        return mbind( plusX, b );
    };
    return mbind( addToEachOfB, a );
}
/*
 * addM2(a, b): computes the same sums as addM, but the inner step maps over b
 * with fmap instead of binding and re-wrapping
 * (Haskell: a >>= \x -> fmap (x +) b).
 */
template< class M >
M addM2( const M& a, const M& b ) {
    auto shiftB = [&b]( int x ) {
        auto addX = [x]( int y ) { return x + y; };
        return fmap( addX, b );
    };
    return mbind( shiftB, a );
}
/*
 * m >>= f : infix monadic bind, sugar for mbind(f, m).
 * Viable only when that mbind call is well-formed (SFINAE on the return type).
 */
template< class Monadic, class Fn >
auto operator >>= ( Monadic&& m, Fn&& f )
    -> decltype( mbind( std::declval<Fn>(), std::declval<Monadic>() ) )
{
    return mbind( std::forward<Fn>(f), std::forward<Monadic>(m) );
}
/*
 * ma >> mb : infix monadic sequencing, sugar for mdo(ma, mb).
 * Note that the right operand is a monad, not a function.
 */
template< class MA, class MB >
auto operator >> ( MA&& ma, MB&& mb )
    -> decltype( mdo( std::declval<MA>(), std::declval<MB>() ) )
{
    return mdo( std::forward<MA>(ma), std::forward<MB>(mb) );
}
/*
 * liftM(f, a, b): lift a binary function into the monad M (Haskell's liftM2).
 * f is applied to every (x, y) drawn from a and b, and each result is wrapped
 * with mreturn<M>.
 */
template< class F, template<class...>class M,
          class X, class Y,
          class R = typename std::result_of<F(X,Y)>::type >
M<R> liftM( F&& f, const M<X>& a, const M<Y>& b ) {
    return a >>= [&]( const X& x ) {
        return b >>= [&]( const Y& y ) {
            // f may be invoked many times (once per pair for sequence
            // monads), so call it as an lvalue instead of forwarding it on
            // every call.
            return mreturn<M>( f(x, y) );
        };
    };
}
/*
 * guard<M>(b): mreturn<M>(true) when b holds, mfail<M<bool>>() otherwise.
 *
 * Placing a guard in front of a computation with `>>` halts it when b is
 * false, because a failed monad short-circuits sequencing:
 *     guard<M>(true)  >> q  = q
 *     guard<M>(false) >> q  = mfail()   -- nullptr for a unique_ptr,
 *                                       -- {} for a vector.
 */
template< template< class... > class M >
M<bool> guard( bool b ) {
    if( !b )
        return mfail< M<bool> >();
    return mreturn<M>( true );
}
/*
 * guard<M>(b, f): shorthand for guard<M>(b) >> mreturn<M>(f()) that never
 * builds the intermediate M<bool>. f is only called when b is true.
 */
template< template< class... > class M, class F,
          class R = typename std::result_of<F()>::type >
M<R> guard( bool b, F&& f ) {
    if( !b )
        return mfail< M<R> >();
    return mreturn<M>( std::forward<F>(f)() );
}
/*
 * uniquePairs(m): every (x, y) drawn from m (via the two-argument mbind) for
 * which x != y. For a sequence this is all ordered pairs of elements with
 * different values. For a pointer monad both x and y are the same pointee,
 * so the result is null (given an ordinary, irreflexive operator!=).
 */
template< template< class... > class M, class X >
M< std::pair<X,X> > uniquePairs( const M<X>& m ) {
    return mbind (
        // Take the elements as X, not a hard-coded int, so any element type
        // with != works.
        []( const X& x, const X& y ) -> M< std::pair<X,X> > {
            // This is a very Haskell-like use of guard.
            return guard<M>( x != y ) >> mreturn<M>( std::make_pair(x,y) );
        }, m, m
    );
}
/* alias for mreturn<unique_ptr> */
template< class X >
auto Just( X&& x ) -> decltype( mreturn<std::unique_ptr>(std::declval<X>()) ) {
return mreturn<std::unique_ptr>( std::forward<X>(x) );
}
#include <cmath>
// Safe square root: Just sqrt(x) for x >= 0, Nothing (null) otherwise,
// including NaN input.
// NOTE(review): a global `sqrt(float)` returning unique_ptr can clash with a
// `float sqrt(float)` that some <cmath>/<math.h> implementations declare in
// the global namespace; consider moving it into a namespace (callers such as
// qroot would need updating).
std::unique_ptr<float> sqrt( float x ) {
    // The more optimized C++-guard: the lambda only runs when x >= 0.
    // Equivalent to: x >= 0 ? Just( std::sqrt(x) ) : nullptr
    return guard<std::unique_ptr>( x >= 0, [x]{ return std::sqrt(x); } );
}
// Safe quadratic root: the two roots of a*x^2 + b*x + c as the pair
// ((-b + r) / 2a, (-b - r) / 2a), where r = sqrt(b^2 - 4ac).
// Returns Nothing (null) when the discriminant is negative, and also when
// a == 0: the equation is then not quadratic and the formula would divide by
// zero, yielding inf/nan.
std::unique_ptr<std::pair<float,float>> qroot( float a, float b, float c ) {
    if( a == 0 )
        return nullptr;
    return fmap (
        [=]( float r /*root of the discriminant*/ ) {
            return std::make_pair( (-b + r)/(2*a), (-b - r)/(2*a) );
        },
        sqrt( b*b - 4*a*c )
    );
}
// Print a pair as "(first,second)".
template< class First, class Second >
std::ostream& operator << ( std::ostream& out,
                            const std::pair<First,Second>& pr ) {
    out << '(' << pr.first << ',' << pr.second << ')';
    return out;
}
// Print a unique_ptr in Haskell Maybe notation: "Just <value>" or "Nothing".
template< class T >
std::ostream& operator << ( std::ostream& out, const std::unique_ptr<T>& ptr ) {
    if( !ptr ) {
        out << "Nothing";
    } else {
        out << "Just " << *ptr;
    }
    return out;
}
// Demo driver: exercises the pointer (Maybe-like) and sequence (list-like)
// monads and prints the results to stdout.
int main() {
    std::unique_ptr<int> p( new int(5) );
    auto f = []( int x ) { return Just(-x); };
    std::unique_ptr<int> q = mbind( f, p );
    std::cout << "q = " << q << std::endl;
    std::cout << "p+q = " << addM2( p, q ) << std::endl;
    std::cout << "p+q = " << liftM( std::plus<int>(), p, q ) << std::endl;

    std::vector<int> v={1,2}, w={3,4};
    std::cout << "v+w = { ";
    auto vw = addM(v,w);
    std::copy (
        std::begin(vw), std::end(vw),
        std::ostream_iterator<int>(std::cout, " ")
    );
    std::cout << '}' << std::endl;

    {
        // Distinct name so it does not shadow the outer `v`.
        std::vector<int> triple = {1,2,3};
        auto ps = uniquePairs( triple );
        std::cout << "Unique pairs of [1,2,3]:\n\t";
        for( const auto& pr : ps )
            std::cout << pr << ' ';
        std::cout << std::endl;
        std::cout << "Unique pairs of Just 5:\n\t" << uniquePairs(p) << std::endl;
    }

    std::cout << "The quadratic root of (1,3,-4) = " << qroot(1,3,-4) << std::endl;
    std::cout << "The quadratic root of (1,0,4) = " << qroot(1,0,4) << std::endl;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment