Skip to content

Instantly share code, notes, and snippets.

@thierryseegers
Last active June 12, 2016 05:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save thierryseegers/b41b03de2a878028f92b04de829e0e16 to your computer and use it in GitHub Desktop.
Save thierryseegers/b41b03de2a878028f92b04de829e0e16 to your computer and use it in GitHub Desktop.
Markov chain implementation, generic on type and n-gram length.
/*
(C) Copyright Thierry Seegers 2016. Distributed under the following license:
Boost Software License - Version 1.0 - August 17th, 2003
Permission is hereby granted, free of charge, to any person or organization
obtaining a copy of the software and accompanying documentation covered by
this license (the "Software") to use, reproduce, display, distribute,
execute, and transmit the Software, and to prepare derivative works of the
Software, and to permit third-parties to whom the Software is furnished to
do so, all subject to the following:
The copyright notices in the Software and this entire statement, including
the above license grant, this restriction and the following disclaimer,
must be included in all copies of the Software, in whole or in part, and
all derivative works of the Software, unless such copies or derivative
works are solely in the form of machine-executable object code generated by
a source language processor.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
*/
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <deque>
#include <iterator>
#include <map>
#include <queue>
#include <random>
#include <utility>
#include <vector>
template<typename T, std::size_t N>
class markov_chain;

namespace detail
{

// Forward iterator over a sequence generated by a random walk through a markov_chain<T, N>.
// The iterator holds a window of the N most recent values ("prefixes"). Dereferencing draws the
// next value from the chain for that window. Because this involves computation (and randomness),
// the drawn value is cached until the iterator is advanced.
// The walk ends when the chain yields markov_chain<T, 1>::sentinel; the iterator then compares
// equal to the chain's cend().
template<typename T, std::size_t N>
class const_iterator
{
public:
    // Member types are spelled out rather than inherited from std::iterator. std::iterator is
    // deprecated in C++17, and its members are not visible unqualified inside a class template
    // deriving from it.
    using iterator_category = std::forward_iterator_tag;
    using value_type = T;
    using difference_type = std::ptrdiff_t;
    using pointer = T const*;
    using reference = T const&;

    // Singular iterator; only assignment and destruction are meaningful on it.
    const_iterator() = default;

    // Returns the current element, drawing it from the chain on first access.
    // Must not be called on an end iterator.
    reference operator*() const
    {
        assert(!at_end);
        assert(prefixes.size() == N);
        if(!cached)
        {
            // Mark the cache valid only once lookup() has returned. If it throws, the next
            // dereference retries instead of returning a stale value.
            suffix = chain->lookup(prefixes.cbegin(), prefixes.cend());
            cached = true;
        }
        return suffix;
    }

    pointer operator->() const
    {
        return &**this;
    }

    // Post-increment: returns a copy of the iterator as it was before advancing.
    const_iterator operator++(int)
    {
        const_iterator previous = *this;
        ++*this;
        return previous;
    }

    // Shifts the current element into the prefix window and draws the next one. Drawing the
    // sentinel turns this iterator into an end iterator.
    const_iterator& operator++()
    {
        assert(!at_end);
        assert(prefixes.size() == N);
        // Draw the current element (if not done yet) while the window still describes it.
        reference current = **this;
        prefixes.pop_front();
        prefixes.push_back(current);
        // The cached value belongs to the old window; force a fresh draw for the new one.
        cached = false;
        if(**this == markov_chain<T, 1>::sentinel)
        {
            at_end = true;
            // Normalize so that the result compares equal to markov_chain::cend().
            prefixes.assign(N, markov_chain<T, 1>::sentinel);
        }
        assert(prefixes.size() == N);
        return *this;
    }

    bool operator==(const_iterator const& other) const
    {
        return at_end == other.at_end && prefixes == other.prefixes;
    }

    bool operator!=(const_iterator const& other) const
    {
        return !(*this == other);
    }

private:
    // Only markov_chain creates non-singular iterators.
    const_iterator(std::deque<T> prefixes, markov_chain<T, N> const& chain, bool end)
        : prefixes(std::move(prefixes)), chain(&chain), at_end(end)
    {
        assert(this->prefixes.size() == N);
    }

    friend class markov_chain<T, N>;

    std::deque<T> prefixes;                         // The N most recent values of the walk.
    markov_chain<T, N> const* chain = nullptr;      // Pointer (not reference) so the iterator is copy-assignable.
    mutable T suffix{};                             // Cached current element; valid only when `cached`.
    mutable bool cached = false;
    bool at_end = false;
};

}
// Order-N Markov chain over values of type T, stored recursively. Each distinct first value of a
// prefix maps to an order-(N-1) chain that handles the rest of the prefix, down to the
// markov_chain<T, 1> base case, which holds the actual successor lists.
// markov_chain<T, 1>::sentinel (T{}) pads the initial prefix and marks the end of a sequence.
// A T{} appearing in the training data therefore also ends generation.
template<typename T, std::size_t N>
class markov_chain
{
    // N == 0 would recurse forever through markov_chain<T, N - 1>.
    static_assert(N >= 1, "markov_chain requires an n-gram length of at least 1");

public:
    using const_iterator = detail::const_iterator<T, N>;

    markov_chain()
    {}

    // Builds the chain from the sequence [first, last) via initialize().
    template<typename InputIt>
    markov_chain(InputIt first, InputIt last)
    {
        initialize(first, last, *this);
    }

    // Records `suffix` as a possible successor of the N-element prefix [begin, end).
    void add(typename std::deque<T>::const_iterator begin, typename std::deque<T>::const_iterator end, T const& suffix)
    {
        assert(static_cast<std::size_t>(std::distance(begin, end)) == N);
        prefix_to_suffix[*begin].add(std::next(begin), end, suffix);
    }

    // Randomly draws a successor of the N-element prefix [begin, end).
    // Throws std::out_of_range if the prefix was never recorded.
    T lookup(typename std::deque<T>::const_iterator begin, typename std::deque<T>::const_iterator end) const
    {
        assert(static_cast<std::size_t>(std::distance(begin, end)) == N);
        return prefix_to_suffix.at(*begin).lookup(std::next(begin), end);
    }

    // Start of a freshly generated random sequence.
    // Returns cend() when the chain is untrained, or when the very first draw is the sentinel
    // (e.g. the chain was trained on an empty sequence). This way the sentinel is never yielded
    // as an element, matching how operator++ treats later draws. The first element is drawn
    // here and stays cached in the returned iterator.
    const_iterator cbegin() const
    {
        if(prefix_to_suffix.empty())
        {
            return cend();
        }
        const_iterator first{ std::deque<T>(N, markov_chain<T, 1>::sentinel), *this, false };
        if(*first == markov_chain<T, 1>::sentinel)
        {
            return cend();
        }
        return first;
    }

    const_iterator cend() const
    {
        return const_iterator{ std::deque<T>(N, markov_chain<T, 1>::sentinel), *this, true };
    }

    // Range-for support; equivalent to cbegin()/cend().
    const_iterator begin() const
    {
        return cbegin();
    }

    const_iterator end() const
    {
        return cend();
    }

private:
    std::map<T, markov_chain<T, N - 1>> prefix_to_suffix;
};
// Order-1 Markov chain over values of type T: the base case of the recursive markov_chain<T, N>.
// Each observed value maps to the list of values seen right after it. Duplicates are kept, so a
// uniform draw from the list reproduces the observed transition frequencies.
// This specialization also owns `sentinel` (T{}), which every order uses to pad the initial
// prefix and to mark the end of a sequence. A T{} in the training data therefore also ends generation.
template<typename T>
class markov_chain<T, 1>
{
public:
    using const_iterator = detail::const_iterator<T, 1>;

    // Start padding and end-of-sequence marker, initialized to T{}.
    static T sentinel;

    markov_chain()
    {}

    // Builds the chain from the sequence [first, last) via initialize().
    template<typename InputIt>
    markov_chain(InputIt first, InputIt last)
    {
        initialize(first, last, *this);
    }

    // Records `suffix` as a possible successor of the one-element prefix [begin, end).
    void add(typename std::deque<T>::const_iterator begin, typename std::deque<T>::const_iterator end, T const& suffix)
    {
        assert(std::distance(begin, end) == 1);
        prefix_to_suffix[*begin].push_back(suffix);
    }

    // Draws a successor of the one-element prefix [begin, end) uniformly from the recorded list.
    // Throws std::out_of_range if the prefix was never recorded.
    // The draw advances the static random_engine shared by every chain with the same T. Calling
    // this concurrently from several threads is therefore a data race.
    T lookup(typename std::deque<T>::const_iterator begin, typename std::deque<T>::const_iterator end) const
    {
        assert(std::distance(begin, end) == 1);
        auto const& suffixes = prefix_to_suffix.at(*begin);
        // Never empty: entries are only ever created by add(), which appends one element.
        std::uniform_int_distribution<std::size_t> pick{0, suffixes.size() - 1};
        return suffixes[pick(random_engine)];
    }

    // Start of a freshly generated random sequence.
    // Returns cend() when the chain is untrained, or when the very first draw is the sentinel
    // (e.g. the chain was trained on an empty sequence). This way the sentinel is never yielded
    // as an element, matching how the iterator's operator++ treats later draws.
    const_iterator cbegin() const
    {
        if(prefix_to_suffix.empty())
        {
            return cend();
        }
        const_iterator first{ std::deque<T>(1, sentinel), *this, false };
        if(*first == sentinel)
        {
            return cend();
        }
        return first;
    }

    const_iterator cend() const
    {
        return const_iterator{ std::deque<T>(1, sentinel), *this, true };
    }

    // Range-for support; equivalent to cbegin()/cend().
    const_iterator begin() const
    {
        return cbegin();
    }

    const_iterator end() const
    {
        return cend();
    }

private:
    static std::mt19937 random_engine;
    std::map<T, std::vector<T>> prefix_to_suffix;
};

template<typename T>
T markov_chain<T, 1>::sentinel = T{};

// Seeded once per T from std::random_device.
template<typename T>
std::mt19937 markov_chain<T, 1>::random_engine = std::mt19937(std::random_device{}());
// Trains `chain` on the sequence [first, last).
// A window of the N most recent values starts out filled with the sentinel. Each input value is
// recorded as a successor of the current window and then shifted into it. Once the input is
// exhausted, the sentinel is recorded as the successor of the final window, marking where
// generated sequences stop.
template<typename InputIt, typename T, std::size_t N>
void initialize(InputIt first, InputIt last, markov_chain<T, N>& chain)
{
    T const& end_marker = markov_chain<T, 1>::sentinel;
    std::deque<T> window(N, end_marker);
    while(first != last)
    {
        chain.add(window.cbegin(), window.cend(), *first);
        window.pop_front();
        window.push_back(*first);
        ++first;
    }
    chain.add(window.cbegin(), window.cend(), end_marker);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment