@alexis779
Created June 7, 2024 04:18
Dataset processing crashed on an A6000 with 32 GB advertised
# Make sure your huggingface token is set in ~/.cache/huggingface/token
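# If the token file is missing, you can also log in programmatically.
# Minimal sketch, assuming the huggingface_hub package is installed and an
# HF_TOKEN environment variable holds a valid access token.
# %%
import os
from huggingface_hub import login

hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    # login() stores the token so later load_dataset/from_pretrained calls can use it
    login(token=hf_token)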
# %%
from datasets import load_dataset
dataset = load_dataset("allenai/multi_lexsum", name="v20230518")
dataset
# %%
sample = dataset['train'][0]
sample
# %%
from transformers import AutoTokenizer
# make sure to accept the license on the model page (the repository is gated)
model_id = 'mistralai/Mistral-7B-v0.3'
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'right'
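# %%
# Optional sanity check of the tokenizer configuration set above.
print(tokenizer.pad_token, tokenizer.eos_token, tokenizer.padding_side)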
# %%
prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Input:
{article}
### Response:
{summary}"""
instruction = "Summarize this document, giving the long version of the summary"
def format_sample(article, summary):
    return prompt.format(instruction=instruction, article=article, summary=summary)
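# Quick check of the prompt template on the sample loaded above (a sketch; it
# assumes the 'sources' and 'summary/long' fields of that record are populated,
# the same fields the batched mapper below relies on).
example_text = format_sample('\n'.join(sample['sources']), sample['summary/long'])
print(example_text[:500])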
source_separator = '\n'
def format_samples(samples: dict[str, list]) -> dict[str, list]:
    ids_list = samples['id']
    sources_list = samples['sources']
    summary_list = samples['summary/long']
    text_list = [
        format_sample(source_separator.join(sources), summary)
        for sources, summary in zip(sources_list, summary_list)
    ]
    inputs = tokenizer(text_list)
    # for causal LM fine-tuning the labels are the input ids themselves
    inputs['labels'] = inputs['input_ids']
    return inputs
max_seq_length = 32768
def filter_samples(sample: dict[str, list]) -> list[bool]:
    # keep only the samples whose tokenized prompt fits within max_seq_length
    input_ids_list = sample['input_ids']
    return [
        len(input_ids) <= max_seq_length
        for input_ids in input_ids_list
    ]
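# Token-length check for the formatted example above, to see whether it would
# survive the max_seq_length filter. Sketch only; reuses example_text from the
# earlier cell.
example_len = len(tokenizer(example_text)['input_ids'])
print(example_len, "tokens; kept by filter:", example_len <= max_seq_length)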
# %%
import pandas as pd
from datasets import DatasetDict
head_dataset = dataset
original_columns = dataset['train'].column_names
num_proc = 1
batch_size = 10
head_dataset = head_dataset.map(format_samples, batched=True, remove_columns=original_columns, batch_size=batch_size, writer_batch_size=batch_size)
head_dataset = head_dataset.filter(filter_samples, batched=True, batch_size=batch_size)
head_dataset
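# %%
# Optional: persist the processed splits so the expensive map/filter pass does not
# have to be re-run after a crash. save_to_disk/load_from_disk are standard
# datasets APIs; the output path below is only a placeholder.
head_dataset.save_to_disk("multi_lexsum_tokenized")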