-
-
Save alexis779/7cd7d6b2d43991c11cbebe43afff0347 to your computer and use it in GitHub Desktop.
Dataset processing crashed on an A6000 GPU machine with 32 GB of advertised memory.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Make sure your huggingface token is set in ~/.cache/huggingface/token
# %%
from datasets import load_dataset

# Multi-LexSum: long legal documents paired with summaries at several
# granularities; pin the dataset version for reproducibility.
dataset = load_dataset("allenai/multi_lexsum", name="v20230518")
dataset
# %%
# Peek at one training record to see the available fields
# (e.g. 'sources', 'summary/long') before building the preprocessing.
sample = dataset['train'][0]
sample
# %%
from transformers import AutoTokenizer

# make sure to approve the licence in the model page
model_id = 'mistralai/Mistral-7B-v0.3'
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Mistral ships no dedicated pad token; reuse EOS so batches can be padded.
tokenizer.pad_token = tokenizer.eos_token
# Right-padding is the usual choice for causal-LM fine-tuning so the real
# tokens stay left-aligned with their labels.
tokenizer.padding_side = 'right'
# %%
# Alpaca-style instruction prompt; {article} and {summary} are filled per sample.
prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Input:
{article}
### Response:
{summary}"""

instruction = "Summarize this document, giving the long version of the summary"


def format_sample(article: str, summary: str) -> str:
    """Render one (article, summary) pair into the instruction prompt.

    The fixed summarization instruction is baked in; the caller supplies the
    source text and the target summary.
    """
    return prompt.format(instruction=instruction, article=article, summary=summary)
# Concatenate a sample's multiple source documents with a newline between them.
source_separator = '\n'


def format_samples(samples: dict[str, list]) -> dict[str, list]:
    """Batched `datasets.map` callback: prompt-format and tokenize samples.

    Joins each sample's source documents, renders the instruction prompt with
    the long summary as the target, and tokenizes the result. Returns the
    tokenizer output with a `labels` column added for causal-LM training.
    """
    sources_list = samples['sources']
    summary_list = samples['summary/long']
    text_list = [
        format_sample(source_separator.join(sources), summary)
        for sources, summary in zip(sources_list, summary_list)
    ]
    # NOTE(review): no truncation here — very long documents produce very long
    # token sequences; they are dropped later by filter_samples.
    inputs = tokenizer(text_list)
    # Causal-LM labels mirror the inputs; copy each list so labels and
    # input_ids do not alias the same objects.
    inputs['labels'] = [list(ids) for ids in inputs['input_ids']]
    return inputs
# Maximum token count kept for training; longer samples are discarded.
max_seq_length = 32768


def filter_samples(sample: dict[str, list]) -> list[bool]:
    """Batched `datasets.filter` callback: keep samples that fit the context.

    Returns one boolean per sample in the batch — True when the tokenized
    length is at most `max_seq_length`.
    """
    input_ids_list = sample['input_ids']
    return [
        len(input_ids) <= max_seq_length
        for input_ids in input_ids_list
    ]
# %%
import pandas as pd
from datasets import DatasetDict

head_dataset = dataset
# Drop the raw text columns after tokenization; only tokenizer output remains.
original_columns = dataset['train'].column_names
num_proc = 1
# Small batch/writer sizes bound peak memory while tokenizing very long
# documents (this step previously crashed on a 32 GB machine).
batch_size = 10
head_dataset = head_dataset.map(
    format_samples,
    batched=True,
    remove_columns=original_columns,
    batch_size=batch_size,
    writer_batch_size=batch_size,
    num_proc=num_proc,
)
# Discard samples whose tokenized length exceeds the model context window.
head_dataset = head_dataset.filter(filter_samples, batched=True, batch_size=batch_size)
head_dataset
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment