Last active
June 12, 2016 05:43
-
-
Save thierryseegers/b41b03de2a878028f92b04de829e0e16 to your computer and use it in GitHub Desktop.
Markov chain implementation, generic on type and n-gram length.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*
(C) Copyright Thierry Seegers 2016. Distributed under the following license:
Boost Software License - Version 1.0 - August 17th, 2003
Permission is hereby granted, free of charge, to any person or organization
obtaining a copy of the software and accompanying documentation covered by
this license (the "Software") to use, reproduce, display, distribute,
execute, and transmit the Software, and to prepare derivative works of the
Software, and to permit third-parties to whom the Software is furnished to
do so, all subject to the following:
The copyright notices in the Software and this entire statement, including
the above license grant, this restriction and the following disclaimer,
must be included in all copies of the Software, in whole or in part, and
all derivative works of the Software, unless such copies or derivative
works are solely in the form of machine-executable object code generated by
a source language processor.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
*/
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <deque>
#include <iterator>
#include <map>
#include <queue>
#include <random>
#include <utility>
#include <vector>
template<typename T, std::size_t N>
class markov_chain;

namespace detail
{
    // Iterator over a random walk through a markov_chain<T, N>.
    //
    // Dereferencing asks the chain for a suffix of the current N-element
    // prefix window (a random draw), so the drawn value is cached: repeated
    // dereferences at one position yield the same value, and incrementing
    // advances the window by exactly that value. The walk ends (and the
    // iterator compares equal to the chain's cend()) once the chain's
    // sentinel value is drawn.
    //
    // NOTE(review): copies of one position draw independently once advanced,
    // so the walk is really single-pass (input-iterator semantics); the
    // forward_iterator_tag category is kept only for backward compatibility.
    template<typename T, std::size_t N>
    class const_iterator
    {
    public:
        // Spelled out rather than inherited from std::iterator: that base is
        // deprecated in C++17, and its members (e.g. `reference`) are invisible
        // to unqualified lookup because the base depends on T.
        using iterator_category = std::forward_iterator_tag;
        using value_type = T;
        using difference_type = std::ptrdiff_t;
        using pointer = T const*;
        using reference = T const&;

        // Returns the value at the current position, drawing it on first use.
        // Read-only: writing through it would feed an unrecorded value into the
        // next prefix window.
        reference operator*() const
        {
            assert(!at_end);
            assert(prefixes.size() == N);
            if(!cached)
            {
                // Only mark the cache valid once lookup() has succeeded.
                suffix = chain->lookup(prefixes.cbegin(), prefixes.cend());
                cached = true;
            }
            return suffix;
        }

        // Post-increment: returns the position before advancing.
        const_iterator operator++(int)
        {
            assert(!at_end);
            assert(prefixes.size() == N);
            // Draw before copying so the returned copy holds the very value this
            // iterator advances by; otherwise the copy would draw its own,
            // possibly different, value when dereferenced.
            this->operator*();
            const_iterator previous = *this;
            ++*this;
            return previous;
        }

        // Pre-increment: slides the prefix window by the current value and
        // draws the next one; switches to the end state on the sentinel.
        const_iterator& operator++()
        {
            assert(!at_end);
            assert(prefixes.size() == N);
            this->operator*();
            prefixes.pop_front();
            prefixes.push_back(suffix);
            // Pre-draw the next value (the cache flag stays set) so the sentinel
            // check can happen here.
            suffix = chain->lookup(prefixes.cbegin(), prefixes.cend());
            if(suffix == markov_chain<T, 1>::sentinel)
            {
                at_end = true;
                // Match the all-sentinel window that cend() is built with.
                prefixes.assign(N, markov_chain<T, 1>::sentinel);
            }
            assert(prefixes.size() == N);
            return *this;
        }

        bool operator==(const_iterator const& other) const
        {
            return at_end == other.at_end && prefixes == other.prefixes;
        }

        bool operator!=(const_iterator const& other) const
        {
            return !(*this == other);
        }

    private:
        friend class markov_chain<T, N>;

        // Only markov_chain creates iterators (see its cbegin()/cend()).
        const_iterator(std::deque<T> start, markov_chain<T, N> const& owner, bool end)
            : prefixes(std::move(start)), chain(&owner), at_end(end)
        {
            assert(prefixes.size() == N);
        }

        std::deque<T> prefixes;          // current N-element prefix window
        markov_chain<T, N> const* chain; // pointer, not reference, so iterators stay copy-assignable
        mutable T suffix{};              // value drawn for the current position
        mutable bool cached = false;     // whether `suffix` holds a drawn value
        bool at_end;                     // past-the-end state
    };
}
template<typename T, std::size_t N> | |
class markov_chain | |
{ | |
public: | |
using const_iterator = detail::const_iterator<T, N>; | |
markov_chain() | |
{} | |
template<typename InputIt> | |
markov_chain(InputIt first, InputIt last) | |
{ | |
initialize(first, last, *this); | |
} | |
void add(typename std::deque<T>::const_iterator begin, typename std::deque<T>::const_iterator end, T const& suffix) | |
{ | |
assert(distance(begin, end) == N); | |
prefix_to_suffix[*begin].add(begin + 1, end, suffix); | |
} | |
T lookup(typename std::deque<T>::const_iterator begin, typename std::deque<T>::const_iterator end) const | |
{ | |
assert(distance(begin, end) == N); | |
return prefix_to_suffix.at(*begin).lookup(begin + 1, end); | |
} | |
const_iterator cbegin() const | |
{ | |
return const_iterator{ std::deque<T>(N, markov_chain<T, 1>::sentinel), *this, false }; | |
} | |
const_iterator cend() const | |
{ | |
return const_iterator{ std::deque<T>(N, markov_chain<T, 1>::sentinel), *this, true }; | |
} | |
private: | |
std::map<T, markov_chain<T, N - 1>> prefix_to_suffix; | |
}; | |
template<typename T> | |
class markov_chain<T, 1> | |
{ | |
public: | |
using const_iterator = detail::const_iterator<T, 1>; | |
static T sentinel; | |
markov_chain() | |
{} | |
template<typename InputIt> | |
markov_chain(InputIt first, InputIt last) | |
{ | |
initialize(first, last, *this); | |
} | |
void add(typename std::deque<T>::const_iterator begin, typename std::deque<T>::const_iterator end, T const& suffix) | |
{ | |
assert(distance(begin, end) == 1); | |
prefix_to_suffix[*begin].push_back(suffix); | |
} | |
T lookup(typename std::deque<T>::const_iterator begin, typename std::deque<T>::const_iterator end) const | |
{ | |
assert(distance(begin, end) == 1); | |
auto const& suffixes = prefix_to_suffix.at(*begin); | |
return suffixes[std::uniform_int_distribution<std::size_t>{0, suffixes.size() - 1}(random_engine)]; | |
} | |
const_iterator cbegin() const | |
{ | |
return const_iterator{ std::deque<T>(1, markov_chain<T, 1>::sentinel), *this, false }; | |
} | |
const_iterator cend() const | |
{ | |
return const_iterator{ std::deque<T>(1, markov_chain<T, 1>::sentinel), *this, true }; | |
} | |
private: | |
static std::mt19937 random_engine; | |
std::map<T, std::vector<T>> prefix_to_suffix; | |
}; | |
template<typename T> | |
T markov_chain<T, 1>::sentinel = T{}; | |
template<typename T> | |
std::mt19937 markov_chain<T, 1>::random_engine = std::mt19937(std::random_device{}()); | |
template<typename InputIt, typename T, std::size_t N> | |
void initialize(InputIt first, InputIt last, markov_chain<T, N>& chain) | |
{ | |
std::deque<T> prefixes(N, markov_chain<T, 1>::sentinel); | |
for(; first != last; ++first) | |
{ | |
chain.add(prefixes.cbegin(), prefixes.cend(), *first); | |
prefixes.pop_front(); | |
prefixes.push_back(*first); | |
} | |
chain.add(prefixes.cbegin(), prefixes.cend(), markov_chain<T, 1>::sentinel); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment