Create BERT TextProcessor
import torch
from torch.utils.data import TensorDataset, random_split, DataLoader
import numpy as np
import warnings
from tqdm import tqdm_notebook as tqdm
from typing import Tuple
NUM_MAX_POSITIONS = 256
BATCH_SIZE = 32
class TextProcessor:
    # special tokens for classification and padding
    CLS = '[CLS]'
    PAD = '[PAD]'

    def __init__(self, tokenizer, label2id: dict, num_max_positions: int = 512):
        self.tokenizer = tokenizer
        self.label2id = label2id
        self.num_labels = len(label2id)
        self.num_max_positions = num_max_positions

    def process_example(self, example: Tuple[str, str]):
        "Convert the text (example[1]) to a sequence of IDs and the label (example[0]) to an integer"
        assert len(example) == 2
        label, text = example[0], example[1]
        assert isinstance(text, str)
        tokens = self.tokenizer.tokenize(text)

        if len(tokens) >= self.num_max_positions:
            # truncate if too long, leaving room for the [CLS] token
            tokens = tokens[:self.num_max_positions - 1]
            ids = self.tokenizer.convert_tokens_to_ids(tokens) + [self.tokenizer.vocab[self.CLS]]
        else:
            # pad if too short, after appending the [CLS] token
            pad = [self.tokenizer.vocab[self.PAD]] * (self.num_max_positions - len(tokens) - 1)
            ids = self.tokenizer.convert_tokens_to_ids(tokens) + [self.tokenizer.vocab[self.CLS]] + pad
        return ids, self.label2id[label]
# download the 'bert-base-cased' tokenizer
from pytorch_transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)

# initialize a TextProcessor; label2int is the label-to-integer mapping, assumed to be built earlier when the dataset is read (not shown in this gist)
processor = TextProcessor(tokenizer, label2int, num_max_positions=NUM_MAX_POSITIONS)
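
As a quick sanity check, a processed example should come back as a fixed-length id sequence plus an integer label. A minimal sketch, assuming label2int maps sentiment strings such as 'positive'/'negative' to integers:

# hypothetical example pair; the actual label strings depend on how label2int was built
example = ("positive", "A wonderfully understated little film.")
ids, label_id = processor.process_example(example)
assert len(ids) == NUM_MAX_POSITIONS  # ids are always truncated or padded to a fixed length
print(ids[:10], label_id)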
ben0it8 commented Jul 11, 2019

# create train and valid sets by splitting the training data (a sketch of create_dataloaders follows below)
train_dl, valid_dl = create_dataloaders(datasets["train"], processor, 
                                    batch_size=finetuning_config.batch_size, 
                                    valid_pct=finetuning_config.valid_pct)

test_dl = create_dataloaders(datasets["test"], processor, 
                             batch_size=finetuning_config.batch_size, 
                             valid_pct=None)
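
The create_dataloaders helper is not included in this gist. A minimal sketch of what it might look like, using only the imports at the top of the file and assuming each dataset is an iterable of (label, text) pairs (the function signature, shuffle behavior, and tensor dtypes here are assumptions):

def create_dataloaders(dataset, processor, batch_size=32, valid_pct=None, shuffle=True):
    "Process (label, text) pairs and return DataLoader(s); split off a validation set when valid_pct is given."
    features = [processor.process_example(ex) for ex in tqdm(dataset, desc="processing")]
    all_ids = torch.tensor([ids for ids, _ in features], dtype=torch.long)
    all_labels = torch.tensor([label for _, label in features], dtype=torch.long)
    tensor_dataset = TensorDataset(all_ids, all_labels)

    # no split requested: return a single loader (e.g. for the test set)
    if valid_pct is None:
        return DataLoader(tensor_dataset, batch_size=batch_size, shuffle=False)

    valid_size = int(valid_pct * len(tensor_dataset))
    train_ds, valid_ds = random_split(tensor_dataset, [len(tensor_dataset) - valid_size, valid_size])
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=shuffle)
    valid_dl = DataLoader(valid_ds, batch_size=batch_size, shuffle=False)
    return train_dl, valid_dl

Under these assumptions, passing valid_pct=None returns one loader, which matches the test_dl call above, while passing a fraction returns the (train_dl, valid_dl) pair.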
