import re


class TokenAlignPreprocessor:
    """Aligns character-level NER labels to wordpiece tokens."""

    def __init__(self, tokenizer, pre_tokenizer, outside_label_id):
        self.tokenizer = tokenizer
        self.pre_tokenizer = pre_tokenizer
        self.outside_label_id = outside_label_id

    def align_label(self, word, word_tokens, char_labels):
        """Collapse per-character labels into one label per wordpiece.

        Each wordpiece covers len(token) characters (ignoring the '##'
        continuation marker); the token's label is the minimum label id in
        that span, which keeps a B- tag over I- and O when the label ids
        are ordered B < I < O.
        """
        i = j = 0
        token_labels = []
        while i < len(word) and j < len(word_tokens):
            step = len(word_tokens[j].replace('##', ''))
            token_labels.append(min(char_labels[i:i + step]))
            i += step
            j += 1
        return token_labels

    def convert_example(self, example):
        """Convert a character-level example into BERT input features."""
        tokens = ['[CLS]']
        labels = [self.outside_label_id]  # label for [CLS]
        text = ''.join(example['tokens'])
        pretokens = self.pre_tokenizer.pre_tokenize_str(text)
        for word, (begin, end) in pretokens:
            word_tokens = self.tokenizer.tokenize(word)
            if '[UNK]' in word_tokens:
                # Unknown pieces cannot be aligned character by character,
                # so label the whole word as outside.
                token_labels = [self.outside_label_id] * len(word_tokens)
            else:
                char_labels = example['ner_tags'][begin:end]
                token_labels = self.align_label(word, word_tokens, char_labels)
            tokens += word_tokens
            labels += token_labels
        tokens.append('[SEP]')
        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        n = len(input_ids)
        token_type_ids = [0] * n
        attention_mask = [1] * n
        labels.append(self.outside_label_id)  # label for [SEP]
        return {'input_ids': input_ids,
                'token_type_ids': token_type_ids,
                'attention_mask': attention_mask,
                'labels': labels}
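

# Illustration of align_label (the wordpiece split below is hypothetical;
# the actual split depends on the tokenizer's vocabulary). With label ids
# B-LC=0, I-LC=1, O=2:
#   word        = '문막휴게소와'
#   word_tokens = ['문막', '##휴게소', '##와']
#   char_labels = [0, 1, 1, 1, 1, 2]
#   -> token_labels = [0, 1, 2]  (min over each token's character span)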


def ner_tokenize(sentence, label2id):
    """Tokenize data tagged like <LC>문막휴게소</LC> into the same format as
    the KLUE NER dataset (character tokens with per-character tags)."""
    # Build a regex alternation over the entity types (e.g. 'LC|PS'),
    # skipping the plain 'O' label, which has no B-/I- prefix.
    tag_list = '|'.join({
        label[2:] for label in label2id if label != 'O'
    })
    start_tags = f'<({tag_list})>'
    end_tags = f'</({tag_list})>'
    tokens = []
    ner_tags = []
    start = None
    n = len(sentence)
    raw_idx = 0    # position in the raw, tagged sentence
    token_idx = 0  # position in the emitted character tokens
    while raw_idx < n:
        token = sentence[raw_idx]
        if token == '<' and (m := re.match(start_tags, sentence[raw_idx:])):
            # Opening tag: remember where the entity starts.
            start = token_idx
            raw_idx += len(m.group(0))
        elif token == '<' and (m := re.match(end_tags, sentence[raw_idx:])):
            # Closing tag: mark the first character B- and the rest I-.
            tag = m.group(1)
            ner_tags[start] = label2id[f'B-{tag}']
            for j in range(start + 1, token_idx):
                ner_tags[j] = label2id[f'I-{tag}']
            raw_idx += len(m.group(0))
        else:
            tokens.append(token)
            ner_tags.append(label2id['O'])
            raw_idx += 1
            token_idx += 1
    return {'tokens': tokens, 'ner_tags': ner_tags}
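

# Usage sketch, not part of the original gist: the checkpoint name
# ('klue/bert-base') and the three-label set below are assumptions; any
# BERT-style wordpiece tokenizer with a matching pre-tokenizer should work.
if __name__ == '__main__':
    from tokenizers.pre_tokenizers import BertPreTokenizer
    from transformers import AutoTokenizer

    # Hypothetical label ids, ordered B < I < O so that min() in
    # align_label keeps the B- tag when a wordpiece spans a boundary.
    label2id = {'B-LC': 0, 'I-LC': 1, 'O': 2}

    example = ner_tokenize('<LC>문막휴게소</LC>', label2id)
    preprocessor = TokenAlignPreprocessor(
        tokenizer=AutoTokenizer.from_pretrained('klue/bert-base'),
        pre_tokenizer=BertPreTokenizer(),
        outside_label_id=label2id['O'],
    )
    features = preprocessor.convert_example(example)
    print(features['input_ids'])
    print(features['labels'])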