Last active
May 30, 2024 03:15
-
-
Save luistung/4f23b7d0026b26560fdd82a3b39ca460 to your computer and use it in GitHub Desktop.
C++ version of BERT tokenization
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* c++ version of tokenization for bert | |
Copyright (C) 2019 luistung | |
This program is free software: you can redistribute it and/or modify | |
it under the terms of the GNU General Public License as published by | |
the Free Software Foundation, either version 3 of the License, or | |
(at your option) any later version. | |
This program is distributed in the hope that it will be useful, | |
but WITHOUT ANY WARRANTY; without even the implied warranty of | |
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
GNU General Public License for more details. | |
You should have received a copy of the GNU General Public License | |
along with this program. If not, see <http://www.gnu.org/licenses/>.*/ | |
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include <boost/algorithm/string.hpp>
#include <utf8proc.h>
//https://unicode.org/reports/tr15/#Norm_Forms | |
//https://ssl.icu-project.org/apiref/icu4c/uchar_8h.html | |
const std::wstring stripChar = L" \t\n\r\v\f"; | |
using Vocab = std::unordered_map<std::wstring, size_t>; | |
using InvVocab = std::unordered_map<size_t, std::wstring>; | |
// Basic (pre-wordpiece) tokenizer: cleans control characters, optionally
// lowercases and strips accents, isolates CJK ideographs and punctuation.
// Mirrors BasicTokenizer from Google's BERT reference tokenization.py.
class BasicTokenizer {
public:
    // doLowerCase: when true, tokenize() lowercases input and strips accents.
    BasicTokenizer(bool doLowerCase=true);
    // Tokenizes UTF-8 `text`; returns wide-string tokens.
    std::vector<std::wstring> tokenize(const std::string& text) const;
private:
    // Drops NUL/0xFFFD/control chars; maps whitespace to a single space.
    std::wstring cleanText(const std::wstring& text) const;
    // True for Unicode control chars (Cc/Cf), except \t \n \r.  (sic: "Controol")
    bool isControol(const wchar_t& ch) const;
    bool isWhitespace(const wchar_t& ch) const;
    bool isPunctuation(const wchar_t& ch) const;
    bool isChineseChar(const wchar_t& ch) const;
    // Pads each CJK ideograph with spaces so it becomes its own token.
    std::wstring tokenizeChineseChars(const std::wstring& text) const;
    bool isStripChar(const wchar_t& ch) const;
    std::wstring strip(const std::wstring& text) const;
    std::vector<std::wstring> split(const std::wstring& text) const;
    // NFD-normalizes and removes combining marks (category Mn).
    std::wstring runStripAccents(const std::wstring& text) const;
    // Splits so each punctuation char is a separate token.
    std::vector<std::wstring> runSplitOnPunc(const std::wstring& text) const;
    bool mDoLowerCase;
};
// Greedy longest-match-first sub-word tokenizer over a fixed vocabulary
// (BERT wordpiece). Non-initial pieces carry the "##" prefix.
class WordpieceTokenizer {
public:
    // vocab: shared token->id map; unkToken: emitted for unsegmentable words;
    // maxInputCharsPerWord: words longer than this map to unkToken.
    WordpieceTokenizer(std::shared_ptr<Vocab> vocab, const std::wstring& unkToken = L"[UNK]", size_t maxInputCharsPerWord=200);
    // Splits each whitespace-separated word of `text` into wordpiece tokens.
    std::vector<std::wstring> tokenize(const std::wstring& text) const;
private:
    std::shared_ptr<Vocab> mVocab;        // token -> id
    std::wstring mUnkToken;               // e.g. L"[UNK]"
    size_t mMaxInputCharsPerWord;
};
// End-to-end BERT tokenizer: BasicTokenizer followed by WordpieceTokenizer,
// plus token<->id conversion against the loaded vocabulary.
class FullTokenizer {
public:
    // vocabFile: path to a one-token-per-line vocab; throws if unreadable.
    FullTokenizer(const std::string& vocabFile, bool doLowerCase = true);
    std::vector<std::wstring> tokenize(const std::string& text) const;
    // Maps each token to its vocab id.
    std::vector<size_t> convertTokensToIds(const std::vector<std::wstring>& text) const;
private:
    std::shared_ptr<Vocab> mVocab;   // token -> id, shared with the wordpiece tokenizer
    InvVocab mInvVocab;              // id -> token (built in the constructor)
    std::string mVocabFile;
    bool mDoLowerCase;
    BasicTokenizer mBasicTokenizer;
    WordpieceTokenizer mWordpieceTokenizer;
};
static std::string normalize_nfd(const std::string& s) { | |
std::string ret; | |
char *result = (char *) utf8proc_NFD((unsigned char *)s.c_str()); | |
if (result) { | |
ret = std::string(result); | |
free(result); | |
result = NULL; | |
} | |
return ret; | |
} | |
static bool isStripChar(const wchar_t& ch) { | |
return stripChar.find(ch) != std::wstring::npos; | |
} | |
static std::wstring strip(const std::wstring& text) { | |
std::wstring ret = text; | |
if (ret.empty()) return ret; | |
size_t pos = 0; | |
while (pos < ret.size() && isStripChar(ret[pos])) pos++; | |
if (pos != 0) ret = ret.substr(pos, ret.size() - pos); | |
pos = ret.size() - 1; | |
while (pos != (size_t)-1 && isStripChar(ret[pos])) pos--; | |
return ret.substr(0, pos + 1); | |
} | |
// Splits `text` at every single whitespace character, keeping empty pieces
// (identical to boost::split with is_any_of and no token compression:
// "" -> [""], " a" -> ["", "a"], "a  b" -> ["a", "", "b"]).
static std::vector<std::wstring> split(const std::wstring& text) {
    // Same delimiter set as the file-level stripChar constant.
    static const std::wstring delims = L" \t\n\r\v\f";
    std::vector<std::wstring> pieces(1);
    for (const wchar_t ch : text) {
        if (delims.find(ch) != std::wstring::npos) {
            pieces.emplace_back();
        } else {
            pieces.back() += ch;
        }
    }
    return pieces;
}
static std::vector<std::wstring> whitespaceTokenize(const std::wstring& text) { | |
std::wstring rtext = strip(text); | |
if (rtext.empty()) return std::vector<std::wstring>(); | |
return split(text); | |
} | |
static std::wstring convertToUnicode(const std::string& text) { | |
size_t i = 0; | |
std::wstring ret; | |
while (i < text.size()) { | |
wchar_t codepoint; | |
utf8proc_ssize_t forward = utf8proc_iterate((utf8proc_uint8_t *)&text[i], text.size() - i, (utf8proc_int32_t*)&codepoint); | |
if (forward < 0) return L""; | |
ret += codepoint; | |
i += forward; | |
} | |
return ret; | |
} | |
static std::string convertFromUnicode(const std::wstring& wText) { | |
char dst[64]; | |
std::string ret; | |
for (auto ch : wText) { | |
utf8proc_ssize_t num = utf8proc_encode_char(ch, (utf8proc_uint8_t *)dst); | |
if (num <= 0) return ""; | |
ret += std::string(dst, dst+num); | |
} | |
return ret; | |
} | |
static std::wstring tolower(const std::wstring& s) { | |
std::wstring ret(s.size(), L' '); | |
for (size_t i = 0; i < s.size(); i++) { | |
ret[i] = utf8proc_tolower(s[i]); | |
} | |
return ret; | |
} | |
static std::shared_ptr<Vocab> loadVocab(const std::string& vocabFile) { | |
std::shared_ptr<Vocab> vocab(new Vocab); | |
size_t index = 0; | |
std::ifstream ifs(vocabFile, std::ifstream::in); | |
if (!ifs) { | |
throw std::runtime_error("open file failed"); | |
} | |
std::string line; | |
while (getline(ifs, line)) { | |
std::wstring token = convertToUnicode(line); | |
if (token.empty()) break; | |
token = strip(token); | |
(*vocab)[token] = index; | |
index++; | |
} | |
return vocab; | |
} | |
// Records whether tokenize() should lowercase input and strip accents.
BasicTokenizer::BasicTokenizer(bool doLowerCase) : mDoLowerCase(doLowerCase) {}
std::wstring BasicTokenizer::cleanText(const std::wstring& text) const { | |
std::wstring output; | |
for (const wchar_t& cp : text) { | |
if (cp == 0 || cp == 0xfffd || isControol(cp)) continue; | |
if (isWhitespace(cp)) output += L" "; | |
else output += cp; | |
} | |
return output; | |
} | |
bool BasicTokenizer::isControol(const wchar_t& ch) const { | |
if (ch== L'\t' || ch== L'\n' || ch== L'\r') return false; | |
auto cat = utf8proc_category(ch); | |
if (cat == UTF8PROC_CATEGORY_CC || cat == UTF8PROC_CATEGORY_CF) return true; | |
return false; | |
} | |
bool BasicTokenizer::isWhitespace(const wchar_t& ch) const { | |
if (ch== L' ' || ch== L'\t' || ch== L'\n' || ch== L'\r') return true; | |
auto cat = utf8proc_category(ch); | |
if (cat == UTF8PROC_CATEGORY_ZS) return true; | |
return false; | |
} | |
bool BasicTokenizer::isPunctuation(const wchar_t& ch) const { | |
if ((ch >= 33 && ch <= 47) || (ch >= 58 && ch <= 64) || | |
(ch >= 91 && ch <= 96) || (ch >= 123 && ch <= 126)) return true; | |
auto cat = utf8proc_category(ch); | |
if (cat == UTF8PROC_CATEGORY_PD || cat == UTF8PROC_CATEGORY_PS | |
|| cat == UTF8PROC_CATEGORY_PE || cat == UTF8PROC_CATEGORY_PC | |
|| cat == UTF8PROC_CATEGORY_PO //sometimes ¶ belong SO | |
|| cat == UTF8PROC_CATEGORY_PI | |
|| cat == UTF8PROC_CATEGORY_PF) return true; | |
return false; | |
} | |
// True for code points in the CJK Unified Ideographs blocks (and extensions)
// and CJK Compatibility Ideographs, as listed in the BERT reference.
// NOTE(review): ranges above 0xFFFF are unreachable where wchar_t is 16-bit
// (e.g. Windows) — confirm the intended target platform.
bool BasicTokenizer::isChineseChar(const wchar_t& ch) const {
    return (ch >= 0x4E00 && ch <= 0x9FFF) ||
           (ch >= 0x3400 && ch <= 0x4DBF) ||
           (ch >= 0x20000 && ch <= 0x2A6DF) ||
           (ch >= 0x2A700 && ch <= 0x2B73F) ||
           (ch >= 0x2B740 && ch <= 0x2B81F) ||
           (ch >= 0x2B820 && ch <= 0x2CEAF) ||
           (ch >= 0xF900 && ch <= 0xFAFF) ||
           (ch >= 0x2F800 && ch <= 0x2FA1F);
}
std::wstring BasicTokenizer::tokenizeChineseChars(const std::wstring& text) const { | |
std::wstring output; | |
for (auto& ch : text) { | |
if (isChineseChar(ch)) { | |
output += L' '; | |
output += ch; | |
output += L' '; | |
} | |
else | |
output += ch; | |
} | |
return output; | |
} | |
std::wstring BasicTokenizer::runStripAccents(const std::wstring& text) const { | |
//Strips accents from a piece of text. | |
std::wstring nText; | |
try { | |
nText = convertToUnicode(normalize_nfd(convertFromUnicode(text))); | |
} catch (std::bad_cast& e) { | |
std::cerr << "bad_cast" << std::endl; | |
return L""; | |
} | |
std::wstring output; | |
for (auto& ch : nText) { | |
auto cat = utf8proc_category(ch); | |
if (cat == UTF8PROC_CATEGORY_MN) continue; | |
output += ch; | |
} | |
return output; | |
} | |
std::vector<std::wstring> BasicTokenizer::runSplitOnPunc(const std::wstring& text) const { | |
size_t i = 0; | |
bool startNewWord = true; | |
std::vector<std::wstring> output; | |
while (i < text.size()) { | |
wchar_t ch = text[i]; | |
if (isPunctuation(ch)) { | |
output.push_back(std::wstring(&ch, 1)); | |
startNewWord = true; | |
} | |
else { | |
if (startNewWord) output.push_back(std::wstring()); | |
startNewWord = false; | |
output[output.size() - 1] += ch; | |
} | |
i++; | |
} | |
return output; | |
} | |
std::vector<std::wstring> BasicTokenizer::tokenize(const std::string& text) const { | |
std::wstring nText = convertToUnicode(text); | |
nText = cleanText(nText); | |
nText = tokenizeChineseChars(nText); | |
const std::vector<std::wstring>& origTokens = whitespaceTokenize(nText); | |
std::vector<std::wstring> splitTokens; | |
for (std::wstring token : origTokens) { | |
if (mDoLowerCase) { | |
token = tolower(token); | |
token = runStripAccents(token); | |
} | |
const auto& tokens = runSplitOnPunc(token); | |
splitTokens.insert(splitTokens.end(), tokens.begin(), tokens.end()); | |
} | |
return whitespaceTokenize(boost::join(splitTokens, L" ")); | |
} | |
// Stores the shared vocabulary, the unknown-token string, and the per-word
// length cap used by tokenize().
WordpieceTokenizer::WordpieceTokenizer(const std::shared_ptr<Vocab> vocab, const std::wstring& unkToken, size_t maxInputCharsPerWord)
    : mVocab(vocab), mUnkToken(unkToken), mMaxInputCharsPerWord(maxInputCharsPerWord) {}
std::vector<std::wstring> WordpieceTokenizer::tokenize(const std::wstring& text) const { | |
std::vector<std::wstring> outputTokens; | |
for (auto& token : whitespaceTokenize(text)) { | |
if (token.size() > mMaxInputCharsPerWord) { | |
outputTokens.push_back(mUnkToken); | |
} | |
bool isBad = false; | |
size_t start = 0; | |
std::vector<std::wstring> subTokens; | |
while (start < token.size()) { | |
size_t end = token.size(); | |
std::wstring curSubstr; | |
bool hasCurSubstr = false; | |
while (start < end) { | |
std::wstring substr = token.substr(start, end - start); | |
if (start > 0) substr = L"##" + substr; | |
if (mVocab->find(substr) != mVocab->end()) { | |
curSubstr = substr; | |
hasCurSubstr = true; | |
break; | |
} | |
end--; | |
} | |
if (!hasCurSubstr) { | |
isBad = true; | |
break; | |
} | |
subTokens.push_back(curSubstr); | |
start = end; | |
} | |
if (isBad) outputTokens.push_back(mUnkToken); | |
else outputTokens.insert(outputTokens.end(), subTokens.begin(), subTokens.end()); | |
} | |
return outputTokens; | |
} | |
// Loads the vocab and wires up the basic and wordpiece tokenizers; also
// builds the inverse (id -> token) map.
FullTokenizer::FullTokenizer(const std::string& vocabFile, bool doLowerCase) :
    mVocab(loadVocab(vocabFile)),
    // Bug fix: mVocabFile and mDoLowerCase were never initialized, leaving
    // mDoLowerCase an indeterminate bool (UB to read).
    mVocabFile(vocabFile),
    mDoLowerCase(doLowerCase),
    mBasicTokenizer(BasicTokenizer(doLowerCase)),
    mWordpieceTokenizer(WordpieceTokenizer(mVocab)) {
    for (auto& v : *mVocab) mInvVocab[v.second] = v.first;
}
std::vector<std::wstring> FullTokenizer::tokenize(const std::string& text) const { | |
std::vector<std::wstring> splitTokens; | |
for (auto& token : mBasicTokenizer.tokenize(text)) | |
for (auto& subToken : mWordpieceTokenizer.tokenize(token)) | |
splitTokens.push_back(subToken); | |
return splitTokens; | |
} | |
std::vector<size_t> FullTokenizer::convertTokensToIds(const std::vector<std::wstring>& text) const { | |
std::vector<size_t> ret(text.size()); | |
for (size_t i = 0; i < text.size(); i++) { | |
ret[i] = (*mVocab)[text[i]]; | |
} | |
return ret; | |
} | |
int main() { | |
FullTokenizer* pTokenizer = nullptr; | |
try { | |
pTokenizer = new FullTokenizer("data/chinese_L-12_H-768_A-12/vocab.txt"); | |
} | |
catch (std::exception& e) { | |
std::cerr << "construct FullTokenizer failed" << std::endl; | |
return -1; | |
} | |
std::string line; | |
while (std::getline(std::cin, line)) { | |
auto tokens = pTokenizer->tokenize(line); | |
auto ids = pTokenizer->convertTokensToIds(tokens); | |
std::cout << "#" << convertFromUnicode(boost::join(tokens, L" ")) << "#" << "\t"; | |
for (size_t i = 0; i < ids.size(); i++) { | |
if (i!=0) std::cout << " "; | |
std::cout << ids[i]; | |
} | |
std::cout << std::endl; | |
} | |
return 0; | |
} |
This has not been compiled on Windows; you may try building it there with CMake.
see official document https://github.com/JuliaStrings/utf8proc?tab=readme-ov-file
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
why not C library?