-
-
Save reuben/ee38a3608c3542e6515c6b9afdb5d670 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/native_client/ctcdecode/path_trie.cpp b/native_client/ctcdecode/path_trie.cpp | |
index 51f75ff3..69f8b6f7 100644 | |
--- a/native_client/ctcdecode/path_trie.cpp | |
+++ b/native_client/ctcdecode/path_trie.cpp | |
@@ -111,22 +111,25 @@ PathTrie* PathTrie::get_path_trie(int new_char, int new_timestep, float cur_log_ | |
} | |
} | |
-PathTrie* PathTrie::get_path_vec(std::vector<int>& output, std::vector<int>& timesteps) { | |
- return get_path_vec(output, timesteps, ROOT_); | |
+PathTrie* PathTrie::get_path_vec(std::vector<int>& output, std::vector<int>& timesteps, std::vector<float>& log_probs) { | |
+ return get_path_vec(output, timesteps, log_probs, ROOT_); | |
} | |
PathTrie* PathTrie::get_path_vec(std::vector<int>& output, | |
std::vector<int>& timesteps, | |
+ std::vector<float>& log_probs, | |
int stop, | |
size_t max_steps) { | |
if (character == stop || character == ROOT_ || output.size() == max_steps) { | |
std::reverse(output.begin(), output.end()); | |
std::reverse(timesteps.begin(), timesteps.end()); | |
+ std::reverse(log_probs.begin(), log_probs.end()); | |
return this; | |
} else { | |
output.push_back(character); | |
timesteps.push_back(timestep); | |
- return parent->get_path_vec(output, timesteps, stop, max_steps); | |
+ log_probs.push_back(log_prob_c); | |
+ return parent->get_path_vec(output, timesteps, log_probs, stop, max_steps); | |
} | |
} | |
diff --git a/native_client/ctcdecode/path_trie.h b/native_client/ctcdecode/path_trie.h | |
index 9b71f35b..7d4ae9fc 100644 | |
--- a/native_client/ctcdecode/path_trie.h | |
+++ b/native_client/ctcdecode/path_trie.h | |
@@ -27,11 +27,12 @@ public: | |
PathTrie* get_path_trie(int new_char, int new_timestep, float log_prob_c, bool reset = true); | |
// get the prefix in index from root to current node | |
- PathTrie* get_path_vec(std::vector<int>& output, std::vector<int>& timesteps); | |
+ PathTrie* get_path_vec(std::vector<int>& output, std::vector<int>& timesteps, std::vector<float>& log_probs); | |
// get the prefix in index from some stop node to current nodel | |
PathTrie* get_path_vec(std::vector<int>& output, | |
std::vector<int>& timesteps, | |
+ std::vector<float>& log_probs, | |
int stop, | |
size_t max_steps = std::numeric_limits<size_t>::max()); | |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.