Last active
November 13, 2023 19:15
-
-
Save ttsuki/94caf749e101c4b54786c0a6c864f710 to your computer and use it in GitHub Desktop.
Walker's Alias Method (C++)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Walker's Alias Method (C++) | |
#include <initializer_list>
#include <limits>
#include <random>
#include <stdexcept>
#include <utility>
#include <vector>
/// Weighted random sampler implementing Walker's Alias Method.
///
/// The table is built once in the constructor; each choose_one() call costs a
/// single std::uniform_int_distribution draw plus one table lookup.
/// All bookkeeping is done in the integer type `Probability`, so a label with
/// weight w is returned with probability exactly w / (sum of all weights),
/// up to the quality of the random engine.
///
/// @tparam Label       type returned by choose_one().
/// @tparam Probability integer weight type; choose_one() additionally requires
///                     it to be a valid std::uniform_int_distribution IntType.
template <class Label, class Probability = size_t>
class weighted_random_choice // implements Walker's Alias Method
{
    /// One column of the alias table. A draw landing in this slot at offset
    /// `col` in [0, total_) yields `lower` if col < threshold, else `higher`.
    struct slot
    {
        Label lower{};
        Label higher{};
        Probability threshold{};
    };

    std::vector<slot> table_; // one slot per label with a non-zero weight
    Probability total_{};     // sum of all weights == height of every slot

public:
    /// ctor from a braced list of {label, weight} pairs.
    weighted_random_choice(
        std::initializer_list<std::pair<Label, Probability>> table)
        : weighted_random_choice(table.begin(), table.end()) { }

    /// ctor from a vector of {label, weight} pairs.
    template <class Alloc>
    weighted_random_choice(
        const std::vector<std::pair<Label, Probability>, Alloc>& table)
        : weighted_random_choice(table.begin(), table.end()) {}

    /// ctor from a range of {label, weight} pairs.
    ///
    /// The range is traversed twice, so the iterators must be (at least)
    /// forward iterators. Labels with weight 0 are accepted and never chosen.
    ///
    /// @throws std::invalid_argument if a weight is negative, if the weights
    ///         sum to 0, or if Probability cannot represent the sum of the
    ///         weights or that sum times the number of non-zero labels.
    template <class IteratorForLabelProbabilityPair>
    weighted_random_choice(
        IteratorForLabelProbabilityPair first,
        IteratorForLabelProbabilityPair last)
    {
        struct iterator_pair
        {
            IteratorForLabelProbabilityPair beg_, end_;
            auto begin() const noexcept { return beg_; }
            auto end() const noexcept { return end_; }
        } const source{first, last};

        constexpr Probability zero{};
        constexpr Probability limit = std::numeric_limits<Probability>::max();

        // Pass 1: validate the weights, sum them, and count the labels that
        // can actually be chosen (weight != 0). Only those get a slot, so the
        // scaling factor below must be this count, not the number of entries.
        size_t n = 0;
        Probability total = 0;
        for (const auto& [label, prob] : source)
        {
            if constexpr (std::numeric_limits<Probability>::is_signed)
            {
                if (prob < zero) throw std::invalid_argument("Probabilities must not be negative.");
            }
            if (prob > limit - total) throw std::invalid_argument("Probability type cannot represent the total.");
            total += prob;
            if (prob != zero) ++n;
        }

        if (total == zero) throw std::invalid_argument("The total must not be 0.");
        // Draws pick a point in [0, n * total), so that product must fit.
        if (static_cast<Probability>(n * total) / total != n) throw std::invalid_argument("Probability type cannot represent total*n.");

        // Pass 2: scale each non-zero weight by n so every slot has integer
        // height `total`, and split labels into under-full / over-full piles.
        std::vector<std::pair<Label, Probability>> lower_deck, higher_deck;
        for (const auto& [label, prob] : source)
        {
            if (prob == zero) continue; // never choosable: gets no slot
            const auto scaled = static_cast<Probability>(prob * n);
            (scaled <= total ? lower_deck : higher_deck).emplace_back(label, scaled);
        }

        table_.reserve(n);
        total_ = total;

        // Each iteration fills exactly one slot. An under-full label keeps its
        // own share and rents the rest (total - threshold) from an over-full
        // label. The remaining scaled weights always sum to
        // (remaining labels) * total, so an over-full label is guaranteed to
        // exist whenever the popped label is strictly under-full.
        while (!lower_deck.empty())
        {
            auto [lower, threshold] = std::move(lower_deck.back());
            lower_deck.pop_back();
            if (threshold == total)
            {
                table_.push_back({lower, lower, threshold});
            }
            else
            {
                auto [higher, rent] = std::move(higher_deck.back());
                higher_deck.pop_back();
                table_.push_back({std::move(lower), higher, threshold});
                rent -= total - threshold;
                (rent <= total ? lower_deck : higher_deck).emplace_back(std::move(higher), rent);
            }
        }
    }

    /// Draws one label at random.
    ///
    /// @param random a uniform random bit generator (e.g. std::mt19937).
    /// @return a copy of the chosen label.
    template <class RandomEngine = std::default_random_engine>
    [[nodiscard]] Label choose_one(RandomEngine& random) const
    {
        const Probability mini = 0;
        const Probability maxi = static_cast<Probability>(table_.size() * total_);
        // Explicit cast: for small Probability types `maxi - 1` is promoted to
        // int, which would be a narrowing conversion inside the braces.
        const Probability last = static_cast<Probability>(maxi - 1);
        const auto index = std::uniform_int_distribution<Probability>{mini, last}(random);
        const auto row = index / total_;
        const auto col = index % total_;
        return col < table_[row].threshold ? table_[row].lower : table_[row].higher;
    }
};
// Test | |
#include <iostream> | |
#include <string> | |
#include <map> | |
int main() | |
{ | |
std::default_random_engine random{std::random_device()()}; | |
const weighted_random_choice<std::string> table | |
{ | |
{"apple", 100}, | |
{"banana", 200}, | |
{"chocolate", 300}, | |
}; | |
std::map<std::string, size_t> result; | |
for (size_t i = 0; i < 1000000; i++) | |
++result[table.choose_one(random)]; | |
for (auto [label, count] : result) | |
std::cout << label << ": " << count << "\n"; | |
// outputs: | |
// apple: 166729 | |
// banana : 333348 | |
// chocolate : 499923 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment