Skip to content

Instantly share code, notes, and snippets.

@seanbenhur
Last active August 30, 2021 10:53
Show Gist options
  • Save seanbenhur/8f015588a10aabfc7a36d954ff7a24c6 to your computer and use it in GitHub Desktop.
Save seanbenhur/8f015588a10aabfc7a36d954ff7a24c6 to your computer and use it in GitHub Desktop.
from datasets import load_dataset
from transformers import AutoTokenizer

# Load the IMDB sentiment dataset (train/test splits of movie reviews).
dataset = load_dataset("imdb")

# Create the tokenizer matching the pretrained checkpoint.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")


def encode_batch(batch):
    """Encode a batch of examples with the model tokenizer.

    Args:
        batch: A dict of lists as provided by ``Dataset.map(batched=True)``;
            the raw review text is under the ``"text"`` key.

    Returns:
        A dict with ``input_ids`` and ``attention_mask``, each truncated or
        padded to exactly 80 tokens so every example has the same length.
    """
    # Fix-length encoding: truncate long reviews, pad short ones to max_length.
    return tokenizer(batch["text"], max_length=80, truncation=True, padding="max_length")


# Encode the input data in batches for speed.
dataset = dataset.map(encode_batch, batched=True)

# The transformers model expects the target class column to be named "labels".
# NOTE: the in-place `rename_column_` was deprecated and removed from the
# `datasets` library; `rename_column` returns a new dataset and must be
# reassigned, otherwise the rename is lost.
dataset = dataset.rename_column("label", "labels")

# Transform to pytorch tensors and only output the required columns.
dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment