-
-
Save luistung/4f23b7d0026b26560fdd82a3b39ca460 to your computer and use it in GitHub Desktop.
/* c++ version of tokenization for bert | |
Copyright (C) 2019 luistung | |
This program is free software: you can redistribute it and/or modify | |
it under the terms of the GNU General Public License as published by | |
the Free Software Foundation, either version 3 of the License, or | |
(at your option) any later version. | |
This program is distributed in the hope that it will be useful, | |
but WITHOUT ANY WARRANTY; without even the implied warranty of | |
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
GNU General Public License for more details. | |
You should have received a copy of the GNU General Public License | |
along with this program. If not, see <http://www.gnu.org/licenses/>.*/ | |
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include <boost/algorithm/string.hpp>
#include <utf8proc.h>
//https://unicode.org/reports/tr15/#Norm_Forms | |
//https://ssl.icu-project.org/apiref/icu4c/uchar_8h.html | |
// Characters treated as whitespace when stripping and splitting tokens.
const std::wstring stripChar = L" \t\n\r\v\f";
// token string -> vocabulary index
using Vocab = std::unordered_map<std::wstring, size_t>;
// vocabulary index -> token string (reverse lookup)
using InvVocab = std::unordered_map<size_t, std::wstring>;
// Basic (pre-wordpiece) tokenization: cleans the text, pads CJK characters
// with spaces, splits on whitespace, optionally lower-cases and strips
// accents, and splits punctuation into standalone tokens.
class BasicTokenizer {
public:
// doLowerCase: when true, tokens are lower-cased and accent-stripped.
BasicTokenizer(bool doLowerCase=true);
// Tokenizes UTF-8 `text` into a list of basic tokens.
std::vector<std::wstring> tokenize(const std::string& text) const;
private:
// Drops NUL / U+FFFD / control characters; maps whitespace to ' '.
std::wstring cleanText(const std::wstring& text) const;
// True for control characters (categories Cc/Cf), excluding \t \n \r.
// (sic: "Controol" — renaming would touch other translation units.)
bool isControol(const wchar_t& ch) const;
// True for ' ', \t, \n, \r, or Unicode category Zs.
bool isWhitespace(const wchar_t& ch) const;
// True for ASCII punctuation ranges or Unicode P* categories.
bool isPunctuation(const wchar_t& ch) const;
// True when the code point lies in a CJK ideograph block.
bool isChineseChar(const wchar_t& ch) const;
// Surrounds each CJK character with single spaces.
std::wstring tokenizeChineseChars(const std::wstring& text) const;
bool isStripChar(const wchar_t& ch) const;
std::wstring strip(const std::wstring& text) const;
std::vector<std::wstring> split(const std::wstring& text) const;
// NFD-normalizes and removes combining marks (category Mn).
std::wstring runStripAccents(const std::wstring& text) const;
// Splits a token so each punctuation character becomes its own piece.
std::vector<std::wstring> runSplitOnPunc(const std::wstring& text) const;
bool mDoLowerCase;
};
// Greedy longest-match-first wordpiece tokenizer over a fixed vocabulary.
// Non-initial pieces carry the "##" prefix; words with no matching pieces
// are emitted as unkToken.
class WordpieceTokenizer {
public:
WordpieceTokenizer(std::shared_ptr<Vocab> vocab, const std::wstring& unkToken = L"[UNK]", size_t maxInputCharsPerWord=200);
// Splits already whitespace-tokenizable `text` into wordpiece sub-tokens.
std::vector<std::wstring> tokenize(const std::wstring& text) const;
private:
std::shared_ptr<Vocab> mVocab;          // shared with FullTokenizer
std::wstring mUnkToken;                 // emitted for unknown words
size_t mMaxInputCharsPerWord;           // longer words map to mUnkToken
};
// End-to-end BERT tokenizer: BasicTokenizer followed by WordpieceTokenizer,
// plus token -> id conversion against the loaded vocabulary.
class FullTokenizer {
public:
// Loads the vocabulary from `vocabFile`; throws std::runtime_error on failure.
FullTokenizer(const std::string& vocabFile, bool doLowerCase = true);
std::vector<std::wstring> tokenize(const std::string& text) const;
// Maps each token to its vocabulary id.
std::vector<size_t> convertTokensToIds(const std::vector<std::wstring>& text) const;
private:
std::shared_ptr<Vocab> mVocab;
InvVocab mInvVocab;        // id -> token, built in the constructor
std::string mVocabFile;    // NOTE(review): confirm the constructor sets this
bool mDoLowerCase;         // NOTE(review): confirm the constructor sets this
BasicTokenizer mBasicTokenizer;
WordpieceTokenizer mWordpieceTokenizer;
};
static std::string normalize_nfd(const std::string& s) { | |
std::string ret; | |
char *result = (char *) utf8proc_NFD((unsigned char *)s.c_str()); | |
if (result) { | |
ret = std::string(result); | |
free(result); | |
result = NULL; | |
} | |
return ret; | |
} | |
static bool isStripChar(const wchar_t& ch) { | |
return stripChar.find(ch) != std::wstring::npos; | |
} | |
static std::wstring strip(const std::wstring& text) { | |
std::wstring ret = text; | |
if (ret.empty()) return ret; | |
size_t pos = 0; | |
while (pos < ret.size() && isStripChar(ret[pos])) pos++; | |
if (pos != 0) ret = ret.substr(pos, ret.size() - pos); | |
pos = ret.size() - 1; | |
while (pos != (size_t)-1 && isStripChar(ret[pos])) pos--; | |
return ret.substr(0, pos + 1); | |
} | |
// Splits `text` on every strip character. Matches boost::split with
// is_any_of(stripChar): adjacent delimiters (and empty input) yield empty
// pieces, so the result is never an empty vector.
static std::vector<std::wstring> split(const std::wstring& text) {
    std::vector<std::wstring> pieces;
    std::wstring current;
    for (const wchar_t ch : text) {
        if (stripChar.find(ch) != std::wstring::npos) {
            pieces.push_back(current);
            current.clear();
        } else {
            current += ch;
        }
    }
    pieces.push_back(current);
    return pieces;
}
static std::vector<std::wstring> whitespaceTokenize(const std::wstring& text) { | |
std::wstring rtext = strip(text); | |
if (rtext.empty()) return std::vector<std::wstring>(); | |
return split(text); | |
} | |
static std::wstring convertToUnicode(const std::string& text) { | |
size_t i = 0; | |
std::wstring ret; | |
while (i < text.size()) { | |
wchar_t codepoint; | |
utf8proc_ssize_t forward = utf8proc_iterate((utf8proc_uint8_t *)&text[i], text.size() - i, (utf8proc_int32_t*)&codepoint); | |
if (forward < 0) return L""; | |
ret += codepoint; | |
i += forward; | |
} | |
return ret; | |
} | |
static std::string convertFromUnicode(const std::wstring& wText) { | |
char dst[64]; | |
std::string ret; | |
for (auto ch : wText) { | |
utf8proc_ssize_t num = utf8proc_encode_char(ch, (utf8proc_uint8_t *)dst); | |
if (num <= 0) return ""; | |
ret += std::string(dst, dst+num); | |
} | |
return ret; | |
} | |
static std::wstring tolower(const std::wstring& s) { | |
std::wstring ret(s.size(), L' '); | |
for (size_t i = 0; i < s.size(); i++) { | |
ret[i] = utf8proc_tolower(s[i]); | |
} | |
return ret; | |
} | |
// Loads a BERT vocab file: one token per line, the 0-based line number is the
// token's id. Throws std::runtime_error if the file cannot be opened.
static std::shared_ptr<Vocab> loadVocab(const std::string& vocabFile) {
std::shared_ptr<Vocab> vocab(new Vocab);
size_t index = 0;
std::ifstream ifs(vocabFile, std::ifstream::in);
if (!ifs) {
throw std::runtime_error("open file failed");
}
std::string line;
while (getline(ifs, line)) {
std::wstring token = convertToUnicode(line);
// NOTE(review): loading stops at the first empty line — and also at the
// first invalid-UTF-8 line, since convertToUnicode returns L"" for both.
// Confirm vocab files never contain blank lines mid-file.
if (token.empty()) break;
token = strip(token);
(*vocab)[token] = index;
index++;
}
return vocab;
}
BasicTokenizer::BasicTokenizer(bool doLowerCase) | |
: mDoLowerCase(doLowerCase) { | |
} | |
std::wstring BasicTokenizer::cleanText(const std::wstring& text) const { | |
std::wstring output; | |
for (const wchar_t& cp : text) { | |
if (cp == 0 || cp == 0xfffd || isControol(cp)) continue; | |
if (isWhitespace(cp)) output += L" "; | |
else output += cp; | |
} | |
return output; | |
} | |
bool BasicTokenizer::isControol(const wchar_t& ch) const { | |
if (ch== L'\t' || ch== L'\n' || ch== L'\r') return false; | |
auto cat = utf8proc_category(ch); | |
if (cat == UTF8PROC_CATEGORY_CC || cat == UTF8PROC_CATEGORY_CF) return true; | |
return false; | |
} | |
bool BasicTokenizer::isWhitespace(const wchar_t& ch) const { | |
if (ch== L' ' || ch== L'\t' || ch== L'\n' || ch== L'\r') return true; | |
auto cat = utf8proc_category(ch); | |
if (cat == UTF8PROC_CATEGORY_ZS) return true; | |
return false; | |
} | |
bool BasicTokenizer::isPunctuation(const wchar_t& ch) const { | |
if ((ch >= 33 && ch <= 47) || (ch >= 58 && ch <= 64) || | |
(ch >= 91 && ch <= 96) || (ch >= 123 && ch <= 126)) return true; | |
auto cat = utf8proc_category(ch); | |
if (cat == UTF8PROC_CATEGORY_PD || cat == UTF8PROC_CATEGORY_PS | |
|| cat == UTF8PROC_CATEGORY_PE || cat == UTF8PROC_CATEGORY_PC | |
|| cat == UTF8PROC_CATEGORY_PO //sometimes ¶ belong SO | |
|| cat == UTF8PROC_CATEGORY_PI | |
|| cat == UTF8PROC_CATEGORY_PF) return true; | |
return false; | |
} | |
// True when the code point falls in a CJK ideograph block.
// NOTE(review): ranges above 0xFFFF are unreachable where wchar_t is 16 bits
// (e.g. Windows) — confirm the target platform uses 32-bit wchar_t.
bool BasicTokenizer::isChineseChar(const wchar_t& ch) const {
    struct Range { unsigned long lo, hi; };
    static const Range kCjkBlocks[] = {
        {0x4E00, 0x9FFF},   {0x3400, 0x4DBF},   {0x20000, 0x2A6DF},
        {0x2A700, 0x2B73F}, {0x2B740, 0x2B81F}, {0x2B820, 0x2CEAF},
        {0xF900, 0xFAFF},   {0x2F800, 0x2FA1F}};
    const unsigned long cp = static_cast<unsigned long>(ch);
    for (const Range& block : kCjkBlocks) {
        if (cp >= block.lo && cp <= block.hi) return true;
    }
    return false;
}
std::wstring BasicTokenizer::tokenizeChineseChars(const std::wstring& text) const { | |
std::wstring output; | |
for (auto& ch : text) { | |
if (isChineseChar(ch)) { | |
output += L' '; | |
output += ch; | |
output += L' '; | |
} | |
else | |
output += ch; | |
} | |
return output; | |
} | |
std::wstring BasicTokenizer::runStripAccents(const std::wstring& text) const { | |
//Strips accents from a piece of text. | |
std::wstring nText; | |
try { | |
nText = convertToUnicode(normalize_nfd(convertFromUnicode(text))); | |
} catch (std::bad_cast& e) { | |
std::cerr << "bad_cast" << std::endl; | |
return L""; | |
} | |
std::wstring output; | |
for (auto& ch : nText) { | |
auto cat = utf8proc_category(ch); | |
if (cat == UTF8PROC_CATEGORY_MN) continue; | |
output += ch; | |
} | |
return output; | |
} | |
std::vector<std::wstring> BasicTokenizer::runSplitOnPunc(const std::wstring& text) const { | |
size_t i = 0; | |
bool startNewWord = true; | |
std::vector<std::wstring> output; | |
while (i < text.size()) { | |
wchar_t ch = text[i]; | |
if (isPunctuation(ch)) { | |
output.push_back(std::wstring(&ch, 1)); | |
startNewWord = true; | |
} | |
else { | |
if (startNewWord) output.push_back(std::wstring()); | |
startNewWord = false; | |
output[output.size() - 1] += ch; | |
} | |
i++; | |
} | |
return output; | |
} | |
std::vector<std::wstring> BasicTokenizer::tokenize(const std::string& text) const { | |
std::wstring nText = convertToUnicode(text); | |
nText = cleanText(nText); | |
nText = tokenizeChineseChars(nText); | |
const std::vector<std::wstring>& origTokens = whitespaceTokenize(nText); | |
std::vector<std::wstring> splitTokens; | |
for (std::wstring token : origTokens) { | |
if (mDoLowerCase) { | |
token = tolower(token); | |
token = runStripAccents(token); | |
} | |
const auto& tokens = runSplitOnPunc(token); | |
splitTokens.insert(splitTokens.end(), tokens.begin(), tokens.end()); | |
} | |
return whitespaceTokenize(boost::join(splitTokens, L" ")); | |
} | |
WordpieceTokenizer::WordpieceTokenizer(const std::shared_ptr<Vocab> vocab, const std::wstring& unkToken, size_t maxInputCharsPerWord) | |
: mVocab(vocab), | |
mUnkToken(unkToken), | |
mMaxInputCharsPerWord(maxInputCharsPerWord) { | |
} | |
std::vector<std::wstring> WordpieceTokenizer::tokenize(const std::wstring& text) const { | |
std::vector<std::wstring> outputTokens; | |
for (auto& token : whitespaceTokenize(text)) { | |
if (token.size() > mMaxInputCharsPerWord) { | |
outputTokens.push_back(mUnkToken); | |
} | |
bool isBad = false; | |
size_t start = 0; | |
std::vector<std::wstring> subTokens; | |
while (start < token.size()) { | |
size_t end = token.size(); | |
std::wstring curSubstr; | |
bool hasCurSubstr = false; | |
while (start < end) { | |
std::wstring substr = token.substr(start, end - start); | |
if (start > 0) substr = L"##" + substr; | |
if (mVocab->find(substr) != mVocab->end()) { | |
curSubstr = substr; | |
hasCurSubstr = true; | |
break; | |
} | |
end--; | |
} | |
if (!hasCurSubstr) { | |
isBad = true; | |
break; | |
} | |
subTokens.push_back(curSubstr); | |
start = end; | |
} | |
if (isBad) outputTokens.push_back(mUnkToken); | |
else outputTokens.insert(outputTokens.end(), subTokens.begin(), subTokens.end()); | |
} | |
return outputTokens; | |
} | |
// Loads the vocabulary and wires up the basic and wordpiece tokenizers.
// Throws std::runtime_error (from loadVocab) when the vocab file is missing.
FullTokenizer::FullTokenizer(const std::string& vocabFile, bool doLowerCase) :
    mVocab(loadVocab(vocabFile)),
    // Bug fix: mVocabFile and mDoLowerCase were never initialized
    // (reading the uninitialized bool would be undefined behavior).
    mVocabFile(vocabFile),
    mDoLowerCase(doLowerCase),
    mBasicTokenizer(BasicTokenizer(doLowerCase)),
    mWordpieceTokenizer(WordpieceTokenizer(mVocab)) {
    // Build the reverse (id -> token) map.
    for (auto& v : *mVocab) mInvVocab[v.second] = v.first;
}
std::vector<std::wstring> FullTokenizer::tokenize(const std::string& text) const { | |
std::vector<std::wstring> splitTokens; | |
for (auto& token : mBasicTokenizer.tokenize(text)) | |
for (auto& subToken : mWordpieceTokenizer.tokenize(token)) | |
splitTokens.push_back(subToken); | |
return splitTokens; | |
} | |
std::vector<size_t> FullTokenizer::convertTokensToIds(const std::vector<std::wstring>& text) const { | |
std::vector<size_t> ret(text.size()); | |
for (size_t i = 0; i < text.size(); i++) { | |
ret[i] = (*mVocab)[text[i]]; | |
} | |
return ret; | |
} | |
int main() { | |
FullTokenizer* pTokenizer = nullptr; | |
try { | |
pTokenizer = new FullTokenizer("data/chinese_L-12_H-768_A-12/vocab.txt"); | |
} | |
catch (std::exception& e) { | |
std::cerr << "construct FullTokenizer failed" << std::endl; | |
return -1; | |
} | |
std::string line; | |
while (std::getline(std::cin, line)) { | |
auto tokens = pTokenizer->tokenize(line); | |
auto ids = pTokenizer->convertTokensToIds(tokens); | |
std::cout << "#" << convertFromUnicode(boost::join(tokens, L" ")) << "#" << "\t"; | |
for (size_t i = 0; i < ids.size(); i++) { | |
if (i!=0) std::cout << " "; | |
std::cout << ids[i]; | |
} | |
std::cout << std::endl; | |
} | |
return 0; | |
} |
OMG, dude, you have to indicate which license you attach to this piece of code.
All right. I've done.
hi你好 ,编译运行显示
i am chinese a in beijing
i am chinese a in beijing
#[UNK] [UNK] [UNK] [UNK] [UNK] [UNK]# 0 0 0 0 0 0
wo 我是中国人
w我是中国人
#[UNK] [UNK] [UNK] [UNK] [UNK] [UNK]# 0 0 0 0 0 0
hi你好 ,编译运行显示
i am chinese a in beijing
i am chinese a in beijing
#[UNK] [UNK] [UNK] [UNK] [UNK] [UNK]# 0 0 0 0 0 0
wo 我是中国人
w我是中国人
#[UNK] [UNK] [UNK] [UNK] [UNK] [UNK]# 0 0 0 0 0 0
Probably, you gave a wrong vocab.txt path.
Just now, I've added exception when failing to open file.
Hello — utf8proc is a C library; how do I get a C++ library for it?
why not C library?
I have not compiled it on Windows; maybe you can try CMake.
see official document https://github.com/JuliaStrings/utf8proc?tab=readme-ov-file
OMG, dude, you have to indicate which license you attach to this piece of code.