Skip to content

Instantly share code, notes, and snippets.

@March-08
Created June 28, 2022 12:33
Show Gist options
  • Save March-08/1bff63505282bdd0f108109e5344e499 to your computer and use it in GitHub Desktop.
Save March-08/1bff63505282bdd0f108109e5344e499 to your computer and use it in GitHub Desktop.
class NerDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(encoded_sentence, label_tensor)`` pairs.

    Every sentence is tokenized once, at construction time, and its
    whitespace-separated tags are aligned to the resulting tokens, so that
    item access is a plain list lookup.

    Relies on two module-level names that are defined outside this class:
    ``tokenizer`` (called on each sentence) and ``match_tokens_labels``
    (called with the tokenizer output and the list of tags).

    Args:
        df: ``pandas.DataFrame`` holding a ``"tags"`` column (one string of
            whitespace-separated tags per row) and a ``"sentence"`` column.

    Raises:
        TypeError: if ``df`` is not a ``pandas.DataFrame``.
        ValueError: if either required column is absent.
    """

    def __init__(self, df):
        # Validate the container type first so the column check below can
        # safely touch ``df.columns``.
        if not isinstance(df, pd.DataFrame):
            raise TypeError('Input should be a dataframe')

        required_columns = ("tags", "sentence")
        if any(column not in df.columns for column in required_columns):
            raise ValueError("Dataframe should contain 'tags' and 'sentence' columns")

        # One list of tag strings per row; ``split()`` with no argument
        # splits on any run of whitespace.
        raw_tags = df["tags"].values.tolist()
        per_row_tags = [tag_string.split() for tag_string in raw_tags]

        sentences = df["sentence"].values.tolist()

        # Each sentence is encoded on its own, padded to "max_length" and
        # truncated, with tensors requested in the "pt" format.
        self.texts = [
            tokenizer(
                sentence,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )
            for sentence in sentences
        ]

        # Align the row's tags with its encoded sentence, pairwise.
        self.labels = list(map(match_tokens_labels, self.texts, per_row_tags))

    def __len__(self):
        """Return the number of examples (one per aligned label sequence)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the encoded sentence at ``idx`` and its labels as a ``LongTensor``."""
        return self.texts[idx], torch.LongTensor(self.labels[idx])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment