Skip to content

Instantly share code, notes, and snippets.

@MichalMalyska
Created April 28, 2020 21:04
Show Gist options
  • Save MichalMalyska/50387452d7eb842175d97a8a7d7601f9 to your computer and use it in GitHub Desktop.
Save MichalMalyska/50387452d7eb842175d97a8a7d7601f9 to your computer and use it in GitHub Desktop.
import os
import pandas as pd
from typing import Dict, List, Iterator, Tuple, Union
import logging
import torch
from overrides import overrides
from transformers import BertTokenizerFast
# AllenNLP imports
from allennlp.data import Instance
from allennlp.data.fields import LabelField, TextField, MetadataField
from allennlp.data.dataset_readers import DatasetReader
from allennlp.data.token_indexers import TokenIndexer, PretrainedTransformerIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token
@DatasetReader.register('ms_edss19_reader')
class ms_edss19_reader(DatasetReader):
    """AllenNLP dataset reader for pre-tokenized notes labelled with ``edss_19``.

    ``_read`` loads a CSV with pandas and uses three columns:

    * ``tokenized_text`` -- a bracketed, comma-separated list of integer
      token ids stored as a string, e.g. ``"[101, 2023, 102]"``.
    * ``edss_19`` -- the numeric label; rows with a missing or negative value
      are skipped.
    * ``patient_id`` -- carried through as instance metadata.
    """

    def __init__(self, tokenizer: str = "BertTokenizerFast", token_indexers: Dict[str, TokenIndexer] = None, **kwargs) -> None:
        """Set up the token indexers and the tokenizer.

        Args:
            tokenizer: Name of the tokenizer to build. Only
                ``"BertTokenizerFast"`` is supported.
            token_indexers: Indexers handed to every ``TextField``. When
                omitted (or empty) a default mapping under the key
                ``"tokens"`` is used.
            **kwargs: Accepted for call-site compatibility; not forwarded to
                the base class.

        Raises:
            NotImplementedError: If ``tokenizer`` names anything other than
                ``"BertTokenizerFast"``.
        """
        super().__init__(lazy=False)
        # NOTE(review): the default below stores the PretrainedTransformerIndexer
        # *class*, not an instance, although the annotation promises
        # TokenIndexer instances. Left unchanged because the model name needed
        # to construct it is not visible here -- confirm and pass a configured
        # instance via ``token_indexers``.
        self.token_indexers = token_indexers or {"tokens": PretrainedTransformerIndexer}
        if tokenizer != "BertTokenizerFast":
            raise NotImplementedError
        self.tokenizer = BertTokenizerFast("/models/base_blue_bert_pt/vocab.txt")

    def text_to_instance(self, text: str, ids: int, labels: float = None) -> Instance:
        """Build an ``Instance`` from one pre-tokenized note.

        Args:
            text: Bracketed, comma-separated integer token ids as a string,
                e.g. ``"[101, 2023, 102]"``.
            ids: Identifier stored (wrapped in a one-element list) in the
                ``"ids"`` metadata field.
            labels: Label value. A falsy value (``None`` or zero) is encoded
                as the label string ``"0.0"``.

        Returns:
            An ``Instance`` with the fields ``"tokens"``, ``"ids"`` and
            ``"label"``.

        Raises:
            ValueError: If any comma-separated piece of ``text`` is not an
                integer.
        """
        # Drop the surrounding brackets, then wrap every id in a Token that
        # carries only its vocabulary index.
        tokens = [Token(text_id=int(piece)) for piece in text[1:-1].split(',')]
        fields = {
            "tokens": TextField(tokens, self.token_indexers),
            "ids": MetadataField([ids]),
        }
        # A label field is always attached; absent labels fall back to 0.0.
        label_value = labels if labels else 0.0
        fields["label"] = LabelField(str(label_value), label_namespace="edss19_labels")
        return Instance(fields)

    def _read(self, file_path: str) -> Iterator[Instance]:
        """Yield one ``Instance`` per usable row of the CSV at ``file_path``.

        Rows are skipped when the text is missing or equals ``"[101, 102]"``,
        or when the ``edss_19`` label is missing or negative.

        Args:
            file_path: Path of the CSV file to read.

        Yields:
            Instances built by ``text_to_instance``.
        """
        df = pd.read_csv(file_path)
        for _, row in df.iterrows():
            text = row["tokenized_text"]
            label = row["edss_19"]
            # "[101, 102]" is presumably a note holding only the special
            # start/end token ids, i.e. no real content -- TODO confirm.
            if pd.isna(text) or text == "[101, 102]":
                continue
            # pandas loads empty cells as NaN, which neither ``== ''`` nor
            # ``is None`` detects and which is not ``< 0``; test for it
            # explicitly so unlabeled rows are not emitted with label "nan".
            if pd.isna(label) or label == '' or label < 0:
                continue
            yield self.text_to_instance(text=text, ids=row["patient_id"], labels=label)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment