import re
from pathlib import Path

from tqdm.auto import tqdm
from transformers import AutoTokenizer
from datasets import load_dataset

# Register tqdm with pandas so that .progress_apply() is available
tqdm.pandas()

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base")
# Utility functions
def calculate_token_length(df, colname="article"):
    """Add a `{colname}_length` column with the token count of each row."""

    def get_token_length(text):
        return len(tokenizer.encode(text, padding=False, truncation=False))

    df[f"{colname}_length"] = df[colname].progress_apply(get_token_length)
    return df
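
# Illustrative usage (not in the original gist): calculate_token_length on a
# tiny hand-made dataframe; the column name and sample texts are made up.
# import pandas as pd
# demo = pd.DataFrame({"article": ["A short sentence.", "A somewhat longer example sentence."]})
# demo = calculate_token_length(demo, "article")
# print(demo["article_length"].tolist())  # per-row token counts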
def fix_punct_whitespace(text: str) -> str:
    # Remove spaces before punctuation marks (except for parentheses)
    text = re.sub(r"\s+([.,;:!?])", r"\1", text)
    # Add a space after punctuation marks (except for parentheses) if missing
    text = re.sub(r"([.,;:!?])(?=[^\s])", r"\1 ", text)
    # Handle spaces around parentheses
    text = re.sub(r"\s?\(\s?", r" (", text)
    text = re.sub(r"\s?\)\s?", r")", text)
    # Add a space after a closing parenthesis if followed by a word
    # or an opening parenthesis
    text = re.sub(r"\)(?=[^\s.,;:!?])", r") ", text)
    # Handle spaces around quotation marks
    text = re.sub(r'\s?"', r'"', text)
    text = re.sub(r'"\s?', r'" ', text)
    # Handle spaces around single quotes
    text = re.sub(r"\s?'", r"'", text)
    text = re.sub(r"'\s?", r"' ", text)
    # Rejoin thousands separators in numbers (e.g. "1, 000" -> "1,000")
    text = re.sub(r"(\d),\s+(\d)", r"\1,\2", text)
    return text
# Example usage
# cleaned_text = fix_punct_whitespace(sample2)
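# Illustrative (not in the original gist): what the cleanup does to a
# typical string with misplaced whitespace:
# fix_punct_whitespace("Hello , world ( example )text")
# -> 'Hello, world (example) text'
# fix_punct_whitespace("a total of 1, 000 samples")
# -> 'a total of 1,000 samples'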
# Main function
def process_dataset(dataset_name, config_name, dataset_split="train"):
    # `config_name` selects the dataset configuration (e.g. "plos" or "elife")
    dataset = load_dataset(dataset_name, config_name)
    short_dataset_name = dataset_name.split("/")[-1]
    # Load data as pandas dataframe
    data_df = dataset[dataset_split].to_pandas().convert_dtypes()
    data_df["article"] = data_df["article"].progress_apply(fix_punct_whitespace)
    data_df["summary"] = data_df["summary"].progress_apply(fix_punct_whitespace)
    data_df = calculate_token_length(data_df, "article")
    data_df = calculate_token_length(data_df, "summary")
    # Save dataframe as a parquet file
    output_dir = Path(f"{short_dataset_name}-{config_name}")
    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_dir / f"{dataset_split}.parquet"
    data_df.to_parquet(output_path)
    print(f"saved to:\n\t{str(output_path)}")
    return data_df
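
# Illustrative (not in the original gist): reading one of the saved files back,
# assuming the "plos" config and "train" split were processed:
# import pandas as pd
# df = pd.read_parquet("scientific_lay_summarisation-plos/train.parquet")
# print(df[["article_length", "summary_length"]].describe())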
# Main
configs = ["plos", "elife"]
dataset_splits = ["train", "test", "validation"]
for config in configs:
    for dataset_split in dataset_splits:
        _ = process_dataset(
            "tomasg25/scientific_lay_summarisation", config, dataset_split
        )