Rebuilding Alpaca with the Hugging Face Trainer Class
from datasets import load_dataset

IGNORE_TOKEN = -100  # label index that PyTorch's cross-entropy loss ignores by default
#####################
# FORMAT DATA #
#####################
template_context = """Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Input:
{input}
### Response:
"""
template_no_context = """Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Response:
"""
def data_to_string(data):
    instruction = data["instruction"]
    context = data["input"]
    response = data["output"]
    template = template_context if len(context) > 0 else template_no_context
    source = template.format(instruction=instruction, input=context)
    return {
        "source": source,  # prompt only
        "text": source + response,  # prompt + target response
    }
original_dataset = load_dataset("tatsu-lab/alpaca")["train"]

dataset = original_dataset.map(data_to_string).remove_columns(
    ["instruction", "input", "output"]
)
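
# Illustrative check (not in the original gist): each mapped record should now
# hold the bare prompt ("source") and the prompt plus response ("text"), e.g.
#   print(dataset[0]["source"])
#   print(dataset[0]["text"])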
#####################
# SPLIT DATA #
#####################
processed_dataset = dataset.train_test_split(test_size=0.1)  # hold out 10% for evaluation
train_dataset = processed_dataset["train"]
eval_dataset = processed_dataset["test"]
#####################
# CREATE DATALOADER #
#####################
def data_collator(features, tokenizer):
    sources = [feature["source"] for feature in features]
    targets = [feature["text"] for feature in features]

    source_tokens = tokenizer(
        sources,
        return_tensors="pt",
        padding="longest",
        max_length=None,
    )
    target_tokens = tokenizer(
        targets,
        return_tensors="pt",
        padding="longest",
        max_length=None,
    )

    labels = target_tokens["input_ids"].clone()

    # Mask the prompt tokens so the loss is computed only on the response.
    # This assumes a right-padding tokenizer (the default), so each source
    # sequence is a prefix of the corresponding target sequence.
    for i in range(len(labels)):
        source_len = source_tokens["attention_mask"][i].sum()
        labels[i, :source_len] = IGNORE_TOKEN

    # Bug fix: also mask padding positions, otherwise the loss would be
    # computed on pad tokens as well.
    labels[target_tokens["attention_mask"] == 0] = IGNORE_TOKEN

    res = {
        "input_ids": target_tokens["input_ids"],
        "attention_mask": target_tokens["attention_mask"],
        "labels": labels,
    }
    return res
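
#####################
#  TRAIN (SKETCH)   #
#####################

# A minimal sketch of wiring everything into the Hugging Face Trainer, assuming
# a causal LM. The model name and TrainingArguments values are illustrative
# assumptions, not part of the original gist. `functools.partial` binds the
# tokenizer, since Trainer calls the collator with the batch of features only,
# and `remove_unused_columns=False` keeps the raw "source"/"text" columns the
# collator needs.

from functools import partial

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

model_name = "meta-llama/Llama-2-7b-hf"  # hypothetical choice of base model
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # causal LMs often lack a pad token

model = AutoModelForCausalLM.from_pretrained(model_name)

trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir="alpaca-out",
        per_device_train_batch_size=4,
        num_train_epochs=3,
        remove_unused_columns=False,  # keep "source"/"text" for the collator
    ),
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=partial(data_collator, tokenizer=tokenizer),
)

trainer.train()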