Created
June 28, 2022 12:33
-
-
Save March-08/1bff63505282bdd0f108109e5344e499 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class NerDataset(torch.utils.data.Dataset):
    """
    Custom dataset implementation to get (text, labels) tuples.

    All sentences are tokenized once, up front, in ``__init__``, and their
    per-token label sequences are computed at the same time, so
    ``__getitem__`` only does a lookup plus a tensor conversion.

    Inputs:
        - df : dataframe with columns [tags, sentence]. Each ``tags`` entry is
          a whitespace-separated string of tags (it is split with
          ``str.split()``).

    NOTE(review): relies on the module-level names ``tokenizer`` and
    ``match_tokens_labels``, which are defined outside this class — confirm
    they are in scope wherever this dataset is constructed.
    """

    def __init__(self, df):
        """
        Tokenize every sentence and align its tags to the tokens.

        Raises:
            TypeError: if ``df`` is not a pandas DataFrame.
            ValueError: if ``df`` lacks a ``tags`` or ``sentence`` column.
        """
        if not isinstance(df, pd.DataFrame):
            raise TypeError('Input should be a dataframe')
        if "tags" not in df.columns or "sentence" not in df.columns:
            raise ValueError("Dataframe should contain 'tags' and 'sentence' columns")

        # Series.tolist() already yields plain Python objects; going through
        # `.values` first is unnecessary.
        tags_list = [tags.split() for tags in df["tags"].tolist()]
        texts = df["sentence"].tolist()

        # Pad/truncate every sentence to a fixed length so samples can be
        # batched together. NOTE(review): with return_tensors="pt" each
        # encoding presumably carries a leading batch dimension of 1
        # (Hugging Face behaviour) — confirm callers squeeze it.
        self.texts = [
            tokenizer(text, padding="max_length", truncation=True, return_tensors="pt")
            for text in texts
        ]
        self.labels = [
            match_tokens_labels(text, tags)
            for text, tags in zip(self.texts, tags_list)
        ]

    def __len__(self):
        """Return the number of sentences in the dataset."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(encoded_text, labels)`` where labels is an int64 tensor."""
        batch_text = self.texts[idx]
        # torch.tensor(..., dtype=torch.long) replaces the legacy
        # torch.LongTensor constructor and produces the same int64 tensor.
        batch_labels = torch.tensor(self.labels[idx], dtype=torch.long)
        return batch_text, batch_labels
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment