Created
June 12, 2022 14:57
-
-
Save kaankarakeben/cd784391d43dbb3cbe998e6ad1a915a9 to your computer and use it in GitHub Desktop.
Encode a dataset with the BERT tokenizer.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@pip_requirements(packages=["transformers"])
@fabric("f-medium")
@model(name="bert-base-uncased-tokenizer")
def download_tokenizer():
    """Fetch and return the pretrained ``bert-base-uncased`` fast tokenizer."""
    from transformers import BertTokenizerFast

    return BertTokenizerFast.from_pretrained("bert-base-uncased")
class PytorchDataset(Dataset):
    """Torch ``Dataset`` that tokenizes one NER dataframe row per item.

    Each item is a dict of 1-D tensors (``input_ids``, ``attention_mask``,
    ..., ``labels``) for a single example, with word-level tags aligned to
    sub-word pieces.

    Args:
        dataframe: frame with ``tokens`` (list of words) and ``ner_tags``
            (list of per-word tag strings) columns.
        tokenizer: a HuggingFace *fast* tokenizer (must expose ``word_ids``).
        tag_to_id: mapping from tag string to integer class id.
        max_len: maximum sequence length for truncation and padding.
    """

    def __init__(self, dataframe, tokenizer, tag_to_id, max_len):
        self.len = len(dataframe)
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.tag_to_id = tag_to_id

    def __getitem__(self, index):
        # When True, every sub-word piece of a word receives that word's
        # label; when False, only the first piece would (the rest get -100).
        label_all_tokens = True
        tokenized_inputs = self.tokenizer(
            [list(self.data.tokens[index])],
            truncation=True,
            is_split_into_words=True,
            # BUG FIX: was hard-coded to 128, silently ignoring the
            # max_len passed to the constructor.
            max_length=self.max_len,
            padding="max_length",
        )
        labels = []
        for i, label in enumerate([list(self.data.ner_tags[index])]):
            word_ids = tokenized_inputs.word_ids(batch_index=i)
            previous_word_idx = None
            label_ids = []
            for word_idx in word_ids:
                if word_idx is None:
                    # Special tokens ([CLS]/[SEP]/padding) are masked out
                    # of the loss with -100.
                    label_ids.append(-100)
                elif label[word_idx] == "0":
                    # NOTE(review): compares against the string "0" —
                    # presumably the "outside" tag; confirm it is not
                    # meant to be the letter "O".
                    label_ids.append(0)
                elif word_idx != previous_word_idx:
                    # First sub-word piece of a new word gets the tag id.
                    label_ids.append(self.tag_to_id[label[word_idx]])
                else:
                    # Continuation piece: repeat the tag or mask it out.
                    label_ids.append(self.tag_to_id[label[word_idx]] if label_all_tokens else -100)
                previous_word_idx = word_idx
            labels.append(label_ids)
        tokenized_inputs["labels"] = labels
        # Drop the single-example batch dimension and convert to tensors.
        single_tokenized_input = {}
        for k, v in tokenized_inputs.items():
            single_tokenized_input[k] = torch.as_tensor(v[0])
        return single_tokenized_input

    def __len__(self):
        return self.len
def create_model_inputs(dataset, tag_to_id):
    """Split *dataset* into train/test frames and wrap each in a PytorchDataset.

    Sampling is seeded (``random_state=200``) so the split is reproducible.
    Relies on the module-level ``tokenizer``, ``TRAIN_EXAMPLES_RATIO`` and
    ``MAX_LEN`` names.
    """
    train_df = dataset.sample(frac=TRAIN_EXAMPLES_RATIO, random_state=200)
    test_df = dataset.drop(train_df.index).reset_index(drop=True)
    train_df = train_df.reset_index(drop=True)

    print("FULL Dataset: {}".format(dataset.shape))
    print("TRAIN Dataset: {}".format(train_df.shape))
    print("TEST Dataset: {}".format(test_df.shape))

    return (
        PytorchDataset(train_df, tokenizer, tag_to_id, MAX_LEN),
        PytorchDataset(test_df, tokenizer, tag_to_id, MAX_LEN),
    )
# Materialize the tokenizer, build the tag vocabulary, and create the splits.
tokenizer = download_tokenizer()

# Count every tag across the dataset, then assign ids in first-seen order
# (Counter preserves insertion order of first appearance).
tag_counter = Counter(tag for tags in ner_dataset["ner_tags"] for tag in tags)
tag_to_id = {tag: ix for ix, tag in enumerate(tag_counter)}

train_set, test_set = create_model_inputs(ner_dataset, tag_to_id)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment