Skip to content

Instantly share code, notes, and snippets.

@khvorov
Created October 11, 2015 02:58
Show Gist options
  • Save khvorov/a994fb6fafe69a804e14 to your computer and use it in GitHub Desktop.
Save khvorov/a994fb6fafe69a804e14 to your computer and use it in GitHub Desktop.
Memoization in C++
// memoize(f, args...) ==> the result of f(args...), computed on first use and
// then cached in m_memomap under a slot derived from hash(&f, args...)
#include <boost/any.hpp>
#include <boost/functional/hash.hpp>
#include <exception>
#include <initializer_list>
#include <iostream>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
template <typename... Args>
std::size_t hash(Args&&... args)
{
std::size_t seed = 0;
(void) std::initializer_list<int> { (boost::hash_combine(seed, std::forward<Args>(args)), 0)... };
return seed;
}
// Process-wide cache shared by every memoize() instantiation. Keys are
// std::size_t slots derived from hashes; values are type-erased (boost::any)
// entries whose layout is private to memoize(). No locking is done here, so
// concurrent use would need external synchronization.
std::unordered_map<std::size_t, boost::any> m_memomap{};
/// The exact type of the expression `f(args...)` for an F invoked with
/// arguments of types Args... (reference-ness of the result is preserved).
/// Only used in an unevaluated context, so nothing is actually called.
template <typename F, typename... Args>
using func_return_type = decltype(std::declval<F&&>()(std::declval<Args&&>()...));
/// Returns f(args...), computing it only on the first call for a given callable
/// and argument values; later matching calls return the cached copy.
///
/// Entries live in m_memomap. Lookup starts at slot hash(&f, args...) and
/// probes forward (slot + 1, slot + 2, ...) on collision. Each entry stores the
/// full call key (the callable's address plus copies of the decayed arguments)
/// next to the result, so a hash collision cannot return another call's result
/// or dereference a failed any_cast.
///
/// Requirements: the decayed argument types must be hashable by boost::hash,
/// copy-constructible and equality-comparable; the result type must be
/// copy-constructible (boost::any stores copies).
///
/// Caveat: callable identity is its address. That is exact for plain functions;
/// for a function object passed as a temporary, a later temporary of the same
/// type at the same address would share its cache entries.
///
/// @return reference to the cached result. It stays valid until that entry is
///         erased from m_memomap (std::unordered_map rehashing does not
///         invalidate references to elements).
template <typename F, typename... Args>
auto memoize(F&& f, Args&&... args) -> const std::decay_t<func_return_type<F, Args...>> &
{
    // boost::any always holds a decayed copy, and any_cast to a reference type
    // does not compile, so work with the decayed result type throughout.
    using result_type = std::decay_t<func_return_type<F, Args...>>;
    // Full identity of a call: which callable (&f) plus the argument values.
    using key_type = std::tuple<decltype(&f), std::decay_t<Args>...>;
    using entry_type = std::pair<key_type, result_type>;

    // Copy the arguments into the key *before* forwarding them to f, which may
    // move from rvalue arguments.
    key_type key(&f, args...);

    // Walk the probe chain from `pos`: return the entry whose key matches, or
    // nullptr with `pos` left on the first unused slot.
    auto probe = [&](std::size_t& pos) -> const entry_type* {
        for (auto it = m_memomap.find(pos); it != m_memomap.end(); it = m_memomap.find(++pos))
        {
            // The pointer form of any_cast yields nullptr when the slot holds an
            // entry of a different type (another function signature).
            const entry_type* entry = boost::any_cast<entry_type>(&it->second);
            if (entry != nullptr && entry->first == key)
                return entry;
        }
        return nullptr;
    };

    std::size_t slot = hash(&f, args...);
    if (const entry_type* hit = probe(slot))
        return hit->second;

    result_type value = f(std::forward<Args>(args)...);

    // f may itself call memoize (e.g. recursive fact/fibo) and occupy `slot`
    // meanwhile, so keep probing before inserting.
    if (const entry_type* hit = probe(slot))
        return hit->second;

    auto inserted = m_memomap.emplace(slot, entry_type(std::move(key), std::move(value)));
    return boost::any_cast<entry_type>(&inserted.first->second)->second;
}
// tests
/// Demo function for memoize(): the sum a + b + c (evaluated as (a + b) + c).
int foo(int a, int b, int c)
{
    const int partial = a + b;
    return partial + c;
}
/// Demo function for memoize(): the product a * b * c, computed in double
/// precision as (a * b) * c.
double bar(double a, int b, int c)
{
    const double scaled = a * b;
    return scaled * c;
}
/// Demo nullary function for memoize(): always returns "string lah".
std::string buzz()
{
    return "string lah";
}
/// n! computed recursively, with each smaller factorial cached via memoize().
/// The arithmetic is unsigned, so once n! exceeds ULLONG_MAX the result wraps
/// around silently (for a 64-bit unsigned long long, for any n > 20).
unsigned long long fact(unsigned long n)
{
    if (n >= 2)
    {
        const unsigned long long below = memoize(fact, n - 1);
        return below * n;
    }
    // 0! and 1! are both 1.
    return 1;
}
/// The n-th Fibonacci number (fibo(0) == 0, fibo(1) == 1), with both
/// predecessors cached via memoize() so each value is computed once.
/// The sum is unsigned and wraps around silently once the value exceeds
/// ULLONG_MAX (for a 64-bit unsigned long long, for any n > 93).
unsigned long long fibo(unsigned long long n)
{
    if (n > 1)
    {
        const unsigned long long one_back = memoize(fibo, n - 1);
        const unsigned long long two_back = memoize(fibo, n - 2);
        return one_back + two_back;
    }
    // Base cases: fibo(0) == 0, fibo(1) == 1.
    return n;
}
/// Demo driver: prints memoized results for the sample functions.
/// Any std::exception thrown along the way (e.g. boost::bad_any_cast) is
/// reported on stderr and turned into a non-zero exit status, so scripts that
/// run this program can detect the failure.
int main()
{
    try
    {
        std::cout << memoize(foo, 1, 2, 3) << std::endl;
        std::cout << memoize(foo, 1, 1, 1) << std::endl;
        std::cout << memoize(bar, 3.14, 2, 3) << std::endl;
        std::cout << memoize(foo, 1, 2, 3) << std::endl;
        std::cout << memoize(buzz) << std::endl;
        std::cout << fact(1) << ", " << fact(30) << ", " << fact(5) << ", " << fact(10) << std::endl;
        std::cout << fibo(1) << ", " << fibo(30) << ", " << fibo(5) << ", " << fibo(10) << std::endl;
    }
    catch (const std::exception& e)
    {
        std::cerr << "caught: " << e.what() << std::endl;
        // Previously fell through to `return 0`, reporting success after an error.
        return 1;
    }
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment