# Thai text normalization and G2P frontend for MeloTTS.
# Gist by @BankNatchapol (last active May 16, 2024).
import re
import unicodedata
from transformers import AutoTokenizer
from . import punctuation, symbols, pu_symbols
from num2words import num2words
from pythainlp.tokenize import word_tokenize
from pythainlp.transliterate import romanize
from pythainlp.util import normalize as thai_normalize
from pythainlp.util import thai_to_eng, eng_to_thai
from melo.text.thai_dictionary import english_dictionary, etc_dictionary
from pythainlp.transliterate import transliterate

def normalize_with_dictionary(text, dic):
    # Replace every dictionary key found in `text` with its mapped value.
    if any(key in text for key in dic.keys()):
        pattern = re.compile("|".join(re.escape(key) for key in dic.keys()))
        return pattern.sub(lambda x: dic[x.group()], text)
    return text
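
# Illustrative usage with a hypothetical mapping (etc_dictionary's real entries
# live in melo.text.thai_dictionary):
#   normalize_with_dictionary("ไปรร.", {"รร.": "โรงเรียน"}) -> "ไปโรงเรียน"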

def normalize(text):
    text = text.strip()
    text = thai_normalize(text)
    text = normalize_with_dictionary(text, etc_dictionary)
    # Expand digit runs into Thai number words.
    text = re.sub(r"\d+", lambda x: num2words(int(x.group()), lang="th"), text)
    text = normalize_english(text)
    text = text.lower()
    return text
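
# For instance (assuming num2words' Thai locale behaves as documented),
# normalize("มี 5 ตัว") expands "5" to the Thai word for five ("ห้า") before
# English mapping and lowercasing are applied.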

def normalize_english(text):
    def fn(m):
        word = m.group()
        if word.upper() in english_dictionary:
            return english_dictionary[word.upper()]
        return "".join(english_dictionary.get(char.upper(), char) for char in word)

    text = re.sub(r"([A-Za-z]+)", fn, text)
    return text
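
# Illustrative behavior: with a hypothetical entry {"AI": "เอไอ"} in
# english_dictionary, normalize_english("AI") returns "เอไอ"; a word without a
# whole-word entry is spelled out letter by letter via per-character lookups.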

# Load the Thai G2P dictionary (word -> list of phonemes).
thai_g2p_dict = {}
from os import path

file_path = path.abspath(__file__)  # full path of this script
dir_path = path.dirname(file_path)
with open(path.join(dir_path, "wiktionary-23-7-2022-clean.tsv"), "r", encoding="utf-8") as f:
    for line in f:
        word, phonemes = line.strip().split("\t")
        thai_g2p_dict[word] = phonemes.split()
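
# Expected TSV row shape, inferred from the parsing above (illustrative row,
# not an actual line from the file):
#   แมว<TAB>m ɛː w ˧
# i.e. a Thai word, a tab, then space-separated phonemes with a tone letter.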

def map_word_to_phonemes(word):
    # Fall back to the word's raw characters for out-of-dictionary words.
    return thai_g2p_dict.get(word, list(word))
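
# Illustrative fallback: a dictionary miss degrades to character-level output,
# e.g. map_word_to_phonemes("abc") -> ['a', 'b', 'c'].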

# Earlier dictionary-based implementation, kept for reference:
# def thai_text_to_phonemes(text):
#     text = normalize(text)
#     words = word_tokenize(text, engine="newmm")
#     phonemes = []
#     for word in words:
#         word_phonemes = map_word_to_phonemes(word)
#         phonemes.extend(word_phonemes)
#     return " ".join(phonemes)

def thai_text_to_phonemes(text):
    text = normalize(text)
    words = word_tokenize(text, engine="newmm")
    phonemes = []
    for word in words:
        word_phonemes = transliterate(word, engine="thaig2p")
        phonemes.extend(word_phonemes)
    return "".join(phonemes)

def text_normalize(text):
    text = normalize(text)
    return text

def distribute_phone(n_phone, n_word):
    phones_per_word = [0] * n_word
    for task in range(n_phone):
        min_tasks = min(phones_per_word)
        min_index = phones_per_word.index(min_tasks)
        phones_per_word[min_index] += 1
    return phones_per_word
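
# Worked example: distribute_phone(5, 2) spreads 5 phones over 2 subword tokens
# as evenly as possible, yielding [3, 2] (earlier slots win ties).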

model_id = 'airesearch/wangchanberta-base-att-spm-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_id)

def refine_ph(phns):
    # Wiktionary-style IPA tone letters, indexed as tone ids:
    # 0 = mid (˧), 1 = low (˨˩), 2 = falling (˥˩), 3 = high (˦˥), 4 = rising (˩˩˦).
    tone_dict = ['˧', '˨˩', '˥˩', '˦˥', '˩˩˦']
    ts = 0
    tone = 0
    for t in tone_dict:
        ts += phns.count(t)
    if ts == 0:  # No tone letter: emit phones with the default (mid) tone id.
        result = phns.lower().strip()
        fi_result = [x for x in list(result) if x not in [' ', '']]
        return fi_result, [tone] * len(fi_result)
    if ts == 1:  # Exactly one tone letter: strip it and tag every phone with it.
        for t in tone_dict:
            if t in phns:
                phns = phns.replace(t, '').strip()
                tone = tone_dict.index(t)
        result = phns.lower().strip()
        fi_result = [x for x in list(result) if x not in [' ', '']]
        return fi_result, [tone] * len(fi_result)
    if ts > 1:  # Several tone letters: split on them and tag each segment.
        tone_pos = []
        for t in tone_dict:
            if t in phns:
                tone_pos.append((tone_dict.index(t), phns.index(t)))
        tone_pos.sort(key=lambda x: x[1])
        sp_phns = re.split('˧|˨˩|˥˩|˦˥|˩˩˦', phns)
        sp_phns = [p.strip() for p in sp_phns if p.strip() != '']
        # print(sp_phns, tone_pos)
        # print(len(tone_pos)==len(sp_phns), f'Tone and phonemes size not matched. {tone_pos} {sp_phns}')
        tones = []
        fi_phs = []
        for i, phn in enumerate(sp_phns):
            aai = [x for x in list(phn) if x not in [' ', '']]
            fi_phs += aai
            if i > (len(tone_pos) - 1):
                tones += [0] * len(aai)
            else:
                tones += [tone_pos[i][0]] * len(aai)
        return fi_phs, tones
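
# Worked example (assuming a thaig2p-style syllable string): refine_ph("kʰ aː ˥˩")
# strips the falling-tone letter and returns (['k', 'ʰ', 'a', 'ː'], [2, 2, 2, 2]).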

def g2p(norm_text):
    tokenized = tokenizer.tokenize(norm_text)
    phs = []
    ph_groups = []
    current_group = []  # Track the current group of subword tokens
    word2ph = []
    tones = []
    for t in tokenized:
        if t in punctuation or t in pu_symbols:  # Special-character token
            phs.append(t)
            word2ph.append(1)
        else:
            if t.startswith("▁"):  # SentencePiece marker: start of a new word
                if current_group:  # Flush the previous word's tokens
                    ph_groups.append(current_group)
                current_group = []  # Reset for the new word
            current_group.append(t.replace("▁", ""))  # Add token to the group
    if current_group:  # Append the last group if not empty
        ph_groups.append(current_group)
    for group in ph_groups:
        phone_len = 0
        text = "".join(group)  # Reassemble the word from its subword tokens
        if text == '[UNK]':  # Handle special cases like unknown tokens
            phs.append('_')
            word2ph.append(1)
            continue
        phonemes = thai_text_to_phonemes(text)
        # Split the phoneme string into syllables on '.' and refine each one.
        phone_list = [l.strip() for l in phonemes.strip().split('.') if l.strip() != '']
        for ph in phone_list:
            ph, tn = refine_ph(ph)
            phs += ph
            tones += tn
            phone_len += len(ph)
        word_len = len(group)
        aaa = distribute_phone(phone_len, word_len)
        assert len(aaa) == word_len, 'len(aaa) != word_len'
        word2ph.extend(aaa)
    # phs.extend(phonemes.split())
    phones = ["_"] + phs + ["_"]
    tones = [0] + tones + [0]
    word2ph = [1] + word2ph + [1]
    assert len(word2ph) == len(tokenized) + 2, 'len(word2ph) != len(tokenized) + 2'
    return phones, tones, word2ph
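
# Illustrative call: phones, tones, word2ph = g2p(text_normalize("สวัสดีครับ"))
# phones/tones are padded with "_"/0 on both ends, and word2ph has one entry per
# WangchanBERTa token plus the two pads.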

def get_bert_feature(text, word2ph, device='cuda', model_id='airesearch/wangchanberta-base-att-spm-uncased'):
    from . import thai_bert
    return thai_bert.get_bert_feature(text, word2ph, device=device, model_id=model_id)

if __name__ == "__main__":
    try:
        from text.symbols import symbols

        text = "ฉันเข้าใจคุณค่าของงานของฉันและความหมายของสิ่งที่ฟอนเทนทำเพื่อคนทั่วไปเป็นอย่างดี ฉันจะใช้ชีวิตอย่างภาคภูมิใจในงานของฉันต่อไป"
        text = text_normalize(text)
        phones, tones, word2ph = g2p(text)
        bert = get_bert_feature(text, word2ph, device='cuda', model_id=model_id)
        # Collect any phones not yet in the symbol inventory and dump them.
        new_symbols = []
        for ph in phones:
            if ph not in symbols and ph not in new_symbols:
                new_symbols.append(ph)
                print('update!, now symbols:')
                print(new_symbols)
                with open('thai_symbol.txt', 'w', encoding='utf-8') as f:
                    f.write(f'{new_symbols}')
    except Exception as e:
        print(f"An error occurred: {e}")