Skip to content

Instantly share code, notes, and snippets.

@DoodleBears
Last active June 28, 2024 18:32
Show Gist options
  • Save DoodleBears/e266c756e70a91cd42b4a5757d7ff2c2 to your computer and use it in GitHub Desktop.
Save DoodleBears/e266c756e70a91cd42b4a5757d7ff2c2 to your computer and use it in GitHub Desktop.
from typing import List
from langdetect import detect
import fast_langdetect
from wtpsplit import WtP
# wtpsplit splitter instance, constructed with the "wtp-bert-mini" name.
# The demo loop at the end of the file calls wtp.split() on each sample text.
wtp = WtP("wtp-bert-mini")
def detect_lang(text: str) -> str:
    """Return the lower-cased language code `langdetect` reports for *text*.

    Any exception raised by `langdetect.detect` propagates to the caller;
    callers in this file wrap the call in ``try``/``except``.
    """
    return str(detect(text)).lower()
def fast_detect_lang(text: str, text_len_threshold: int = 3) -> str:
    """Detect the language of *text*, choosing the detector by text length.

    Text whose length is ``<= text_len_threshold`` goes through
    `detect_lang` (langdetect); longer text goes through
    `fast_langdetect.detect`. Passing ``text_len_threshold=0`` therefore
    forces `fast_langdetect` for any non-empty text.

    Author's rationale: for Chinese, Japanese and Korean text of three
    characters or fewer, fasttext gives worse results than langdetect.
    langdetect, in turn, often reports Chinese as Korean, so callers should
    double check a Korean result with fasttext.

    Returns:
        The detected language code, lower-cased.
    """
    if len(text) <= text_len_threshold:
        return detect_lang(text)
    detection = fast_langdetect.detect(text, low_memory=False)
    return str(detection["lang"]).lower()
# Maps raw detector codes to the language tags used by the block helpers.
# Codes missing from this table are looked up with a default of "en" by the
# callers in this file; "x" is the marker for "language not defined".
lang_map = {
    # Chinese
    "zh": "zh",
    "zh-cn": "zh",
    # NOTE(review): "zh-tw" maps to the undefined marker "x" rather than "zh";
    # this looks intentional - confirm.
    "zh-tw": "x",
    # Korean
    "ko": "ko",
    # Japanese
    "ja": "ja",
}
def adjust_middle_block(concat_result):
    """Relabel a middle block when its two neighbours agree on a language.

    Slides a three-block window over *concat_result*, a list of mutable
    ``[lang, text]`` pairs. When the left and right blocks share the same
    defined language (anything but ``"x"``) and the middle block is either
    undefined (``"x"``) or at most one character long, the middle block takes
    that language.

    The list is modified in place and nothing is returned. A relabelled
    block is the left neighbour of the next window, so a change can
    propagate rightwards within a single pass. Lists with fewer than three
    blocks are left untouched.
    """
    windows = zip(concat_result, concat_result[1:], concat_result[2:])
    for left_block, middle_block, right_block in windows:
        if left_block[0] == "x" or left_block[0] != right_block[0]:
            continue
        if middle_block[0] == "x" or len(middle_block[1]) <= 1:
            middle_block[0] = left_block[0]
def adjust_side_block(concat_result):
    """Assign a language to the first and last blocks from their neighbours.

    *concat_result* is a list of mutable ``[lang, text]`` pairs and is
    modified in place:

    * First block undefined (``"x"``): take the language of the first
      defined block in the list; it stays ``"x"`` if there is none.
    * First block defined but at most one character long: take the nearest
      defined language to its right.
    * Last block undefined (``"x"``): take the nearest defined language to
      its left.

    The two directional lookups go through
    `find_nearest_lang_with_direction`, which falls back to ``"en"`` when no
    defined block exists on that side. An empty list is a no-op.
    """
    if not concat_result:
        # Nothing to adjust; indexing [0] / [-1] below would raise IndexError.
        return

    first_block = concat_result[0]
    if first_block[0] == "x":
        defined_lang = next(
            (block[0] for block in concat_result if block[0] != "x"), None
        )
        if defined_lang is not None:
            first_block[0] = defined_lang
    elif len(first_block[1]) <= 1:
        first_block[0] = find_nearest_lang_with_direction(
            concat_result, 0, is_left=False
        )  # search right

    last_block = concat_result[-1]
    if last_block[0] == "x":
        last_block[0] = find_nearest_lang_with_direction(
            concat_result, len(concat_result) - 1, is_left=True
        )  # search left
def fill_missing_languages(concat_result):
    """Give every undefined (``"x"``) block the language of a nearby block.

    *concat_result* is a list of mutable ``[lang, text]`` pairs and is
    modified in place:

    * For the first or last block, `find_nearest_lang` picks the closest
      defined language.
    * For any other block, `decide_direction` chooses a side and
      `find_nearest_lang_with_direction` searches that side.

    Blocks filled earlier in the pass count as defined for later blocks.
    This step should run AFTER `adjust_middle_block` for better results.
    """
    last_index = len(concat_result) - 1
    for index, block in enumerate(concat_result):
        if block[0] != "x":
            continue
        if index in (0, last_index):
            block[0] = find_nearest_lang(concat_result, index)
        else:
            is_left = decide_direction(concat_result, index)
            block[0] = find_nearest_lang_with_direction(
                concat_result, index, is_left
            )
def find_nearest_lang_with_direction(concat_result, index, is_left, default="en"):
    """Return the closest defined language on one side of ``concat_result[index]``.

    Args:
        concat_result: list of ``[lang, text]`` pairs.
        index: position of the block whose neighbours are searched.
        is_left: search towards the start of the list when true, towards
            the end otherwise.
        default: value returned when no block on that side has a language
            other than ``"x"``. Defaults to ``"en"``, the previously
            hard-coded fallback.

    Returns:
        The language of the nearest block on the chosen side whose language
        is not ``"x"``, or *default*.
    """
    if is_left:
        positions = range(index - 1, -1, -1)
    else:
        positions = range(index + 1, len(concat_result))
    for position in positions:
        lang = concat_result[position][0]
        if lang != "x":
            return lang
    return default
def find_nearest_lang(concat_result, index, default="en"):
    """Return the closest defined language around ``concat_result[index]``.

    Searches outwards one step at a time in both directions. At equal
    distance the left neighbour wins, because it is checked first.

    Args:
        concat_result: list of ``[lang, text]`` pairs.
        index: position of the block whose neighbours are searched.
        default: value returned when every other block is undefined
            (``"x"``). Defaults to ``"en"``, the previously hard-coded
            fallback.
    """
    total = len(concat_result)
    for distance in range(1, total):
        for position in (index - distance, index + distance):
            if 0 <= position < total and concat_result[position][0] != "x":
                return concat_result[position][0]
    return default
def decide_direction(concat_result, index):
    """Decide whether the block at *index* should merge to its left.

    Rules, in order:

    * The first block can only merge right: returns ``False``.
    * The last block can only merge left: returns ``True``.
    * Otherwise merge left when the left neighbour's text is shorter than
      the right neighbour's, or when the right neighbour's language is
      neither ``"ja"`` nor ``"zh"``; merge right in every other case.

    Author's note: based on test results, merging towards the shorter
    substring works better, and the JA/ZH side is preferred.
    NOTE(review): the left neighbour's language is never inspected, so a
    non-JA/ZH left side still wins whenever its text is shorter.

    Returns:
        ``True`` to merge left, ``False`` to merge right.
    """
    if index == 0:
        return False
    if index == len(concat_result) - 1:
        return True

    left_block = concat_result[index - 1]
    right_block = concat_result[index + 1]
    left_is_shorter = len(left_block[1]) < len(right_block[1])
    right_is_cjk = right_block[0] in ("ja", "zh")
    return left_is_shorter or not right_is_cjk
def merge_blocks(concat_result):
    """Merge runs of consecutive blocks that share a language.

    Args:
        concat_result: list of ``[lang, text]`` pairs.

    Returns:
        A new list in which each run of adjacent same-language blocks is
        collapsed into one block whose text is the concatenation of the run.
        The input list and its blocks are not modified.
    """
    merged = []
    for block in concat_result:
        # Compare against the last merged block instead of a "" sentinel, so
        # a first block whose language is "" cannot index an empty list.
        if merged and merged[-1][0] == block[0]:
            merged[-1][1] += block[1]
        else:
            # Copy so that growing the merged text never alters the caller's block.
            merged.append(list(block))
    return merged
# MARK: check_languages
def check_languages(smart_concat_result):
    """Re-detect the language of every block and update it in place.

    For each ``[lang, text]`` block the text is passed to
    `fast_detect_lang` and the result is mapped through `lang_map` (codes
    not in the table become ``"en"``). A detector failure is treated as
    ``"en"``.

    A Korean result is double checked with ``text_len_threshold=0``, which
    makes `fast_detect_lang` use fasttext for any non-empty text. If that
    second check does not also say Korean (or fails), the block takes the
    nearest defined language on the side picked by `decide_direction`.

    A block keeps its current language when the new result is the undefined
    marker ``"x"``.
    """
    for index, block in enumerate(smart_concat_result):
        try:
            cur_lang = fast_detect_lang(block[1])
        except Exception:
            # Best effort: any detector failure falls back to English.
            cur_lang = "en"
        cur_lang = lang_map.get(cur_lang, "en")

        if cur_lang == "ko":
            # Chinese, Japanese and Korean text of length <= 3 gets bad
            # results from fasttext compared to langdetect. However,
            # langdetect often reports Chinese as Korean, so a Korean result
            # is only trusted when fasttext agrees.
            try:
                fast_lang = fast_detect_lang(block[1], text_len_threshold=0)
            except Exception:
                # Guarded like the first call; a failed check cannot confirm
                # Korean.
                fast_lang = ""
            if fast_lang != "ko":
                is_left = decide_direction(smart_concat_result, index)
                cur_lang = find_nearest_lang_with_direction(
                    smart_concat_result, index, is_left
                )

        if cur_lang != "x":
            block[0] = cur_lang
def smart_concat_240629(concat_result):
    """Run one pass of the block-merging pipeline.

    Steps, in order: relabel short/undefined middle blocks, merge adjacent
    same-language blocks, re-detect languages, relabel middle blocks again,
    fill the remaining undefined blocks, fix up the first and last blocks,
    merge again, and re-detect once more.

    Args:
        concat_result: list of mutable ``[lang, text]`` pairs; the in-place
            steps may modify it.

    Returns:
        The list of merged blocks after the pass. An empty input yields an
        empty list.
    """
    if not concat_result:
        # adjust_side_block indexes the first and last block, so an empty
        # list must not reach the pipeline.
        return []

    # combine short substring first
    adjust_middle_block(concat_result)
    concat_result = merge_blocks(concat_result)
    check_languages(concat_result)

    adjust_middle_block(concat_result)
    fill_missing_languages(concat_result)
    adjust_side_block(concat_result)
    concat_result = merge_blocks(concat_result)
    check_languages(concat_result)
    return concat_result
# MARK: smart_concat
def smart_concat(concat_result):
    """Repeat `smart_concat_240629` until the block list is stable.

    A result counts as complete when no block is undefined (``"x"``) and no
    two adjacent blocks share a language. The pipeline always runs at least
    once for a non-empty input.

    Args:
        concat_result: list of mutable ``[lang, text]`` pairs.

    Returns:
        The block list produced by the last pass. An empty input is returned
        unchanged, since the pipeline cannot process it.
    """
    if not concat_result:
        return concat_result

    while True:
        concat_result = smart_concat_240629(concat_result)
        has_undefined = any(block[0] == "x" for block in concat_result)
        has_adjacent_duplicate = any(
            left[0] == right[0]
            for left, right in zip(concat_result, concat_result[1:])
        )
        if not (has_undefined or has_adjacent_duplicate):
            return concat_result
def init_substr_lang(substr: List[str]):
    """Tag every substring with an initial language.

    Each substring is passed to `detect_lang` and the result is mapped
    through `lang_map`; codes not in the table become ``"en"``.

    When detection raises, the substring inherits the previous block's
    language so that it concatenates with the previous substring (author's
    note: punctuation has no language). The inherited value is used as-is,
    including the undefined marker ``"x"``. If there is no previous block the
    language is ``"en"``.

    Args:
        substr: the split substrings, in order.

    Returns:
        A list of ``[lang, text]`` pairs, one per substring.
    """
    concat_result = []
    prev_lang = ""
    for block in substr:
        try:
            detected = detect_lang(block)
        except Exception:
            # prev_lang is already a mapped tag; sending it through lang_map
            # again would turn an inherited "x" into "en".
            cur_lang = prev_lang or "en"
        else:
            cur_lang = lang_map.get(detected, "en")
        concat_result.append([cur_lang, block])
        prev_lang = cur_lang
    return concat_result
def lang_concat(text: str, substr: List[str], smart_combine=True):
    """Concatenate over-split substrings based on their language.

    Tags each substring with `init_substr_lang`, optionally merges the
    blocks with `smart_concat`, then prints one ``lang|index: text`` line
    per block followed by a separator line.

    Args:
        text (str): original text. Currently unused.
        substr (List[str]): the split substrings.
        smart_combine (bool, optional): when true, run `smart_concat`, which
            relabels a substring that is too short (1 char) when the
            substrings on both sides of it share a language, and merges
            same-language neighbours. When false, the initial per-substring
            tags are printed as they are. Defaults to True.

    Returns:
        The list of ``[lang, text]`` blocks that was printed.
    """
    concat_result = init_substr_lang(substr)
    if smart_combine:
        concat_result = smart_concat(concat_result)
    for index, block in enumerate(concat_result):
        print(f"{block[0]}|{index}: {block[1]}")
    print("------------")
    return concat_result
# MARK: main
# Sample inputs for the demo loop: mixed Chinese / Japanese / English / Korean
# sentences, with and without punctuation between the language switches.
texts = [
    "我是 VGroupChatBot,一个旨在支持多人通信的助手,通过可视化消息来帮助团队成员更好地交流。我可以帮助团队成员更好地整理和共享信息,特别是在讨论、会议和Brainstorming等情况下。你好我的名字是西野くまですmy name is bob很高兴认识你どうぞよろしくお願いいたします「こんにちは」是什么意思。",
    "你好,我的名字是西野くまです。I am from Tokyo, 日本の首都。今天的天气非常好,sky is clear and sunny。おはようございます、皆さん!我们一起来学习吧。Learning languages can be fun and exciting。昨日はとても忙しかったので、今日は少しリラックスしたいです。Let's take a break and enjoy some coffee。中文、日本語、and English are three distinct languages, each with its own unique charm。希望我们能一起进步,一起成长。Let's keep studying and improving our language skills together. ありがとう!",
    "你好,今日はどこへ行きますか?",
    "我的名字是田中さんです。",
    "我喜欢吃寿司和拉面、おいしいです。",
    "我喜欢吃寿司和拉面おいしいです。",
    "今天の天気はとてもいいですね。",
    "我在学习日本語、少し難しいです。",
    "我在学习日本語少し難しいです。",
    "日语真是おもしろい啊",
    "你喜欢看アニメ吗?",
    "我想去日本旅行、特に京都に行きたいです。",
    # These two were fused into one sample by a missing comma (implicit
    # string concatenation).
    "昨天見た映画はとても感動的でした。",
    "我朋友是日本人、彼はとても優しいです。",
    "我们一起去カラオケ吧、楽しそうです。",
    "你今天吃了什么、朝ごはんは何ですか?",
    "我的家在北京、でも、仕事で東京に住んでいます。",
    "我喜欢读书、本を読むのが好きです。",
    "这个周末、一緒に公園へ行きましょうか?",
    "你的猫很可爱、あなたの猫はかわいいです。",
    "我在学做日本料理、日本料理を作るのを習っています。",
    "你会说几种语言、何ヶ国語話せますか?",
    "我昨天看了一本书、その本はとても面白かったです。",
    "我们一起去逛街、買い物に行きましょう。",
    "你最近好吗、最近どうですか?",
    "我在学做日本料理와 한국 요리、日本料理を作るのを習っています。",
    "你会说几种语言、何ヶ国語話せますか?몇 개 언어를 할 수 있어요?",
    "我昨天看了一本书、その本はとても面白かったです。어제 책을 읽었는데, 정말 재미있었어요。",
    "我们一起去逛街와 쇼핑、買い物に行きましょう。쇼핑하러 가요。",
    "你最近好吗、最近どうですか?요즘 어떻게 지내요?",
]
# Demo: split every sample text with the wtpsplit splitter, then let
# lang_concat tag, merge and print the resulting blocks.
for text in texts:
    substr = wtp.split(text, threshold=5e-5)
    lang_concat(text, substr, True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment