Skip to content

Instantly share code, notes, and snippets.

Last active January 14, 2024 13:41
Show Gist options
  • Save mbant/67875e0464cd9d1402413532e3244261 to your computer and use it in GitHub Desktop.
Save mbant/67875e0464cd9d1402413532e3244261 to your computer and use it in GitHub Desktop.
Beam Search - C++
#include <algorithm>
#include <array>
#include <chrono>
#include <cmath>
#include <iostream>
#include <limits>
#include <string>
#include <utility>
#include <vector>

using namespace std;
// Vocabulary of the decoder: the 26 lowercase ASCII letters, in order.
constexpr array<char,26> vocab = {
    'a','b','c','d','e','f','g','h','i','j','k','l','m',
    'n','o','p','q','r','s','t','u','v','w','x','y','z'
};

// Bigram language model stored as log-probabilities, indexed as
// [current - 'a'][next - 'a']. It is filled in by main() before decoding.
vector<vector<double>> log_p_language_model;

// lm(current, next): log-probability that `next` follows `current`.
// Both characters are expected to be lowercase letters from `vocab`;
// indexing is unchecked, so anything else reads out of bounds.
double lm(const char current, const char next)
{
    const auto& row = log_p_language_model[current - 'a'];
    const auto column = next - 'a';
    return row[column];
}
// we want to decode a string of length 'length',
// starting for a given character
// and computing the next character by interfacing with our very sophisticated Language Model
// lm(current,next) that will return us the probability that the 'next' character follows the 'current'
// finally output the best string and its log_probability
// we can do it greedily (not optimal but easier to code)
pair<string,double> greedy_decoding(char start, size_t length)
string decoded(length,'0');
decoded[0] = start;
double log_p = 0;
for(int i=1; i<length; i++)
double max_p { numeric_limits<double>::lowest() };
char best_c;
for(auto const& c : vocab)
double p = lm(decoded[i-1],c);
if( p > max_p )
max_p = p;
best_c = c;
decoded[i] = best_c;
log_p += max_p;
return make_pair(decoded,log_p);
// A partial hypothesis tracked by beam search: the string decoded so far and
// its path log-probability (the sum of the per-step lm() log-probabilities).
struct Beam{
    double log_p;
    string s;
    Beam() : log_p(0.), s("") { }
    // `_s` is taken by value and moved in, so callers passing a temporary
    // (e.g. `parent.s + c`) do not pay for a second copy.
    Beam(double _lp, string _s) : log_p(_lp), s(std::move(_s)) { }
};

// Orders beams by path log-probability: true when b1 is strictly less probable
// than b2. Used as the "less" comparator for the heap / max_element algorithms,
// so the most probable beam is treated as the maximum.
bool beam_comp(const Beam& b1, const Beam& b2){ return b1.log_p < b2.log_p; }
// Now let's code beam search
/* expand_beams(size_t beam_idx, vector<Beam>& beams, vector<Beam>& next_beams, size_t beam_size)
* Expand beams[j] and push the top beam_size into next_beams
void expand_beam(size_t beam_idx, vector<Beam>& beams, vector<Beam>& next_beams, size_t beam_size)
vector<Beam> next_possible_beams{};
for(auto const& c : vocab)
next_possible_beams.push_back( Beam( beams[beam_idx].log_p + lm(beams[beam_idx].s.back() , c) , beams[beam_idx].s + c ) );
// check these 26 elements for the top 'beam_size' ones
// using the whole path probability
for(int k=0; k<beam_size; ++k)
// next_beams[j*beam_size+k] = next_possible_beams.back();
/* expand_all_beams(vector<Beam>& beams, vector<Beam>& next_beams, size_t beam_size)
* Expand all the `beams` into `next_beams` and heapify it
void expand_all_beams(vector<Beam>& beams, vector<Beam>& next_beams, size_t beam_size)
for(int j=0; j<beams.size(); j++)
/* expand_all_beams_moremem(vector<Beam>& beams, vector<Beam>& next_beams, size_t beam_size)
* Expand all the `beams` into `next_beams` and heapify it
* If memory isn't a concern, use this. It skips finding the top `beam_size` from every expansion
* and simply store all intermediate results into heapified `next_beams`; selection unchanged
vector<Beam> expand_all_beams_moremem(vector<Beam>& beams, size_t beam_size)
vector<Beam> next_beams{};
for(int j=0; j<beams.size(); j++)
for(auto const& c : vocab)
next_beams.push_back( Beam( beams[j].log_p + lm(beams[j].s.back() , c) , beams[j].s + c ) );
return next_beams;
/* select_best_beams(vector<Beam>& beams, vector<Beam>& next_beams, size_t beam_size)
* Selects from the `beam_size*beam_size` elements in the heapyfied `next_beams`
* and insert replaces `beams` with the best `beam_size` ones
void select_best_beams(vector<Beam>& beams, vector<Beam>& next_beams, size_t beam_size)
// select the best 'beam_size' ones
vector<char> last_states(beam_size);
for(int k=0; k<beam_size; ++k)
// optimization (Asif) : if 2 paths collide
// you can discard the least probable since you know it's not gonna be the max ever
while( k > 0 && // the most probable goes in regardless
// next_beams.size() < (beam_size - k) && // we need at least beam_size - k remaining to fill all the slots -- or do we? at worst less beams for next
next_beams.back().s.back()) != last_states.end()) // merge condition
beams.push_back( next_beams.back() );
/* beam_decoding(char start, size_t length, size_t beam_size)
* Uses beam serch to find the optimal string, given:
* - a language model `lm`
* - the desired `lenght`
* - a `beam_size`
* for beam_size = 1 degenerates to greedy decoding
* higher beam_size trades computational complexity for optimality.
* beam_size = infinity becomes Viterbi decoding, which will enumerate all possible
* strings and select the best one
* This function returns a tuple with the selected string and its path-logprobability
pair<string,double> beam_decoding(char start, size_t length, size_t beam_size)
// one beam only at the start
vector<Beam> beams(1);
beams[0].s = start;
for(int i=1; i<length; i++)
vector<Beam> next_beams{};
// get the max between the beams
int max_idx = 0;
for(int i=1; i<beams.size(); i++)
if( beam_comp(beams[max_idx],beams[i]) )
max_idx = i;
return make_pair(beams[max_idx].s,beams[max_idx].log_p);
/* fast_beam_search(char start, size_t length, size_t beam_size)
* Less memory efficient version on beam_decoding(*),
* see beam_decoding docstring for details
* possibly slightly faster, but same asymptotic complexity
pair<string,double> fast_beam_decoding(char start, size_t length, size_t beam_size)
// one beam only at the start
vector<Beam> beams(1);
beams[0].s = start;
for(int i=1; i<length; i++)
auto next_beams = expand_all_beams_moremem(beams,beam_size);
// get the max between the beams
int max_idx = 0;
for(int i=1; i<beams.size(); i++)
if( beam_comp(beams[max_idx],beams[i]) )
max_idx = i;
return make_pair(beams[max_idx].s,beams[max_idx].log_p);
int main(int argc, char const *argv[])
// Init the very sophisticated language model
log_p_language_model = vector<vector<double>>(26);
for( auto& row : log_p_language_model)
row = vector<double>(26,numeric_limits<double>::lowest());
log_p_language_model[0][1] = log(0.55); // a->b = 55%
log_p_language_model[0][2] = log(0.45); // a->c = 45%
log_p_language_model[1][0] = log(0.15); // b->a = 15% -- remaining low
for(int i=1; i<26; i++)
log_p_language_model[1][i] = log((1.-log_p_language_model[1][0])/25);
log_p_language_model[2][0] = log(0.4); // c->a = 40% -- remaining low
for(int i=1; i<26; i++)
log_p_language_model[2][i] = log((1.-log_p_language_model[2][0])/25);
// Ideally in this setting greedy decoding should oscillate from a to b
// while beam search should be able to "discover" that the path through c gives higher probability overall
// Define beam size and length of the decoded string
size_t beam_size = 5;
size_t length = 30;
// Get starting timepoint
auto start = std::chrono::high_resolution_clock::now();
auto gd = greedy_decoding('a',length);
cout << "Best (greedily) decoded string: " << gd.first << "\n\twhich had a probability of " << gd.second << endl;
// should be ababababab...
// Get ending timepoint
auto stop = std::chrono::high_resolution_clock::now();
// Get duration. Substart timepoints to
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(stop - start);
cout << "Time taken by Greedy Decoding: "
<< duration.count() << " microseconds" << endl << endl;
start = std::chrono::high_resolution_clock::now();
auto bd = beam_decoding('a',length,beam_size);
cout << "Best (beam_search) decoded string: " << bd.first << "\n\twhich had a probability of " << bd.second << endl;
// should be acacacacac... (with a final a or final b) and a log probability greater than the geedily decoded string
stop = std::chrono::high_resolution_clock::now();
duration = std::chrono::duration_cast<std::chrono::microseconds>(stop - start);
cout << "Time taken by Beam Decoding: "
<< duration.count() << " microseconds" << endl << endl;
start = std::chrono::high_resolution_clock::now();
auto fbd = fast_beam_decoding('a',length,beam_size);
cout << "Best (fast_beam_search) decoded string: " << fbd.first << "\n\twhich had a probability of " << fbd.second << endl;
// should be the same as the memory-efficient version
stop = std::chrono::high_resolution_clock::now();
duration = std::chrono::duration_cast<std::chrono::microseconds>(stop - start);
cout << "Time taken by Fast Beam Decoding: "
<< duration.count() << " microseconds" << endl << endl;
return 0;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment