Skip to content

Instantly share code, notes, and snippets.

@ttsuki
Last active November 13, 2023 19:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ttsuki/94caf749e101c4b54786c0a6c864f710 to your computer and use it in GitHub Desktop.
Save ttsuki/94caf749e101c4b54786c0a6c864f710 to your computer and use it in GitHub Desktop.
Walker's Alias Method (C++)
// Walker's Alias Method (C++)
#include <utility>
#include <initializer_list>
#include <vector>
#include <random>
#include <stdexcept>
/// Samples labels with probability proportional to integer weights, using
/// Walker's alias method: O(n) construction, O(1) per draw.
///
/// The table has one slot per input pair. Every slot covers exactly `total_`
/// units: units [0, threshold) map to `lower`, units [threshold, total_) map
/// to `higher` (the alias). A draw picks one unit uniformly out of
/// table_.size() * total_.
///
/// @tparam Label        type returned by choose_one(); must be
///                      default-constructible and copyable.
/// @tparam Probability  integer weight type. It is passed to
///                      std::uniform_int_distribution, so it must be one of the
///                      integer types that distribution accepts.
template <class Label, class Probability = size_t>
class weighted_random_choice // implements Walker's Alias Method
{
    struct slot
    {
        Label lower{};            // returned when the drawn column < threshold
        Label higher{};           // alias, returned otherwise
        Probability threshold{};  // number of this slot's units owned by `lower`
    };

    std::vector<slot> table_;     // exactly one slot per input pair
    Probability total_{};         // sum of all input weights

public:
    /// Builds the table from a braced list of (label, weight) pairs.
    weighted_random_choice(
        std::initializer_list<std::pair<Label, Probability>> table)
        : weighted_random_choice(table.begin(), table.end()) { }

    /// Builds the table from a vector of (label, weight) pairs.
    template <class Alloc>
    weighted_random_choice(
        const std::vector<std::pair<Label, Probability>, Alloc>& table)
        : weighted_random_choice(table.begin(), table.end()) {}

    /// Builds the table from the range [first, last) of (label, weight) pairs.
    ///
    /// The range is traversed twice, so it must be at least a forward range.
    /// Pairs with weight 0 are accepted; they are never returned by choose_one().
    ///
    /// @throws std::invalid_argument if the weights sum to 0 (this includes an
    ///         empty range), or if n * total does not fit in Probability.
    template <class IteratorForLabelProbabilityPair>
    weighted_random_choice(
        IteratorForLabelProbabilityPair first,
        IteratorForLabelProbabilityPair last)
    {
        struct iterator_pair
        {
            IteratorForLabelProbabilityPair beg_, end_;
            auto begin() const noexcept { return beg_; }
            auto end() const noexcept { return end_; }
        } const source{first, last};

        // counts the number of elements and the total
        size_t n = 0;
        Probability total = 0;
        for (const auto& [label, prob] : source)
        {
            ++n;
            total += prob;
        }
        if (total == 0) throw std::invalid_argument("The total must not be 0.");
        if (static_cast<Probability>(n * total) / total != n) throw std::invalid_argument("Probability type cannot represent total*n.");

        // Scale every weight by n so the n slots hold exactly `total` units each.
        // Zero weights must still be dealt into the lower deck: each one needs a
        // slot of its own (threshold 0, so that slot always yields its alias).
        // Skipping them would leave n * total units for fewer than n slots,
        // which strands mass in the higher deck and skews (or, with an empty
        // table, breaks) sampling.
        std::vector<std::pair<Label, Probability>> lower_deck, higher_deck;
        for (const auto& [label, prob] : source)
            (prob * n <= total ? lower_deck : higher_deck).emplace_back(label, prob * n);

        this->table_.reserve(n);
        this->total_ = total;

        // Invariant: the weights left in both decks sum to
        // (entries left) * total. So whenever a light entry is strictly below
        // `total`, some heavy entry (> total) must still exist to fill its slot.
        while (!lower_deck.empty())
        {
            auto [lower, threshold] = std::move(lower_deck.back());
            lower_deck.pop_back();
            if (threshold == total)
            {
                // exactly full: this slot needs no alias
                this->table_.push_back({lower, lower, threshold});
            }
            else
            {
                // underfull: lend the remaining (total - threshold) units
                // from a heavy entry, then re-deal what is left of it
                auto [higher, rent] = std::move(higher_deck.back());
                higher_deck.pop_back();
                this->table_.push_back({std::move(lower), higher, threshold});
                rent -= total - threshold;
                (rent <= total ? lower_deck : higher_deck).emplace_back(std::move(higher), rent);
            }
        }
    }

    /// Draws one label; a label with weight w is returned with probability
    /// w / total.
    /// @param random a uniform random bit generator (e.g. std::mt19937).
    template <class RandomEngine = std::default_random_engine>
    Label choose_one(RandomEngine& random) const
    {
        // Pick one of table_.size() * total_ equally likely units, then split
        // it into (slot row, column within that slot).
        const Probability mini = 0, maxi = static_cast<Probability>(table_.size() * total_);
        const auto index = std::uniform_int_distribution<Probability>{mini, maxi - 1}(random);
        const auto row = index / total_;
        const auto col = index % total_;
        return col < table_[row].threshold ? table_[row].lower : table_[row].higher;
    }
};
// Test
#include <iostream>
#include <string>
#include <map>
int main()
{
    // Seeded from std::random_device, so runs typically draw different sequences.
    std::default_random_engine random{std::random_device()()};
    const weighted_random_choice<std::string> table
    {
        {"apple", 100},
        {"banana", 200},
        {"chocolate", 300},
    };

    // Draw one million samples and tally how often each label comes up.
    std::map<std::string, size_t> result;
    for (size_t i = 0; i < 1000000; i++)
        ++result[table.choose_one(random)];

    // Bind by const reference: plain `auto [label, count]` would copy every
    // map entry, including its std::string key.
    for (const auto& [label, count] : result)
        std::cout << label << ": " << count << "\n";

    // Sample output (counts vary between runs; the expected ratio is 1:2:3):
    // apple: 166729
    // banana: 333348
    // chocolate: 499923
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment