Skip to content

Instantly share code, notes, and snippets.

@reuben
Created October 16, 2019 07:46
Show Gist options
  • Star (0): you must be signed in to star a gist.
  • Fork (0): you must be signed in to fork a gist.
  • Save reuben/ee38a3608c3542e6515c6b9afdb5d670 to your computer and use it in GitHub Desktop.
diff --git a/native_client/ctcdecode/path_trie.cpp b/native_client/ctcdecode/path_trie.cpp
index 51f75ff3..69f8b6f7 100644
--- a/native_client/ctcdecode/path_trie.cpp
+++ b/native_client/ctcdecode/path_trie.cpp
@@ -111,22 +111,25 @@ PathTrie* PathTrie::get_path_trie(int new_char, int new_timestep, float cur_log_
}
}
-PathTrie* PathTrie::get_path_vec(std::vector<int>& output, std::vector<int>& timesteps) {
- return get_path_vec(output, timesteps, ROOT_);
+PathTrie* PathTrie::get_path_vec(std::vector<int>& output, std::vector<int>& timesteps, std::vector<float>& log_probs) {
+ return get_path_vec(output, timesteps, log_probs, ROOT_);
}
PathTrie* PathTrie::get_path_vec(std::vector<int>& output,
std::vector<int>& timesteps,
+ std::vector<float>& log_probs,
int stop,
size_t max_steps) {
if (character == stop || character == ROOT_ || output.size() == max_steps) {
std::reverse(output.begin(), output.end());
std::reverse(timesteps.begin(), timesteps.end());
+ std::reverse(log_probs.begin(), log_probs.end());
return this;
} else {
output.push_back(character);
timesteps.push_back(timestep);
- return parent->get_path_vec(output, timesteps, stop, max_steps);
+ log_probs.push_back(log_prob_c);
+ return parent->get_path_vec(output, timesteps, log_probs, stop, max_steps);
}
}
diff --git a/native_client/ctcdecode/path_trie.h b/native_client/ctcdecode/path_trie.h
index 9b71f35b..7d4ae9fc 100644
--- a/native_client/ctcdecode/path_trie.h
+++ b/native_client/ctcdecode/path_trie.h
@@ -27,11 +27,12 @@ public:
PathTrie* get_path_trie(int new_char, int new_timestep, float log_prob_c, bool reset = true);
// get the prefix in index from root to current node
- PathTrie* get_path_vec(std::vector<int>& output, std::vector<int>& timesteps);
+ PathTrie* get_path_vec(std::vector<int>& output, std::vector<int>& timesteps, std::vector<float>& log_probs);
// get the prefix in index from some stop node to current nodel
PathTrie* get_path_vec(std::vector<int>& output,
std::vector<int>& timesteps,
+ std::vector<float>& log_probs,
int stop,
size_t max_steps = std::numeric_limits<size_t>::max());
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment