The filter_dataset function reads the train, test, and validation CSV splits from a directory, counts the word tokens in every column with NLTK's word_tokenize, drops rows whose min_word_column falls below a minimum word count, optionally drops rows with a non-NaN "input" value (along with the "input" column itself), and saves the filtered splits to a new directory. It returns the describe() summary statistics of each filtered split, keyed by split name.
from pathlib import Path

import nltk
import pandas as pd
from nltk.tokenize import word_tokenize
from tqdm.auto import tqdm

nltk.download("punkt", quiet=True)  # word_tokenize requires the punkt tokenizer models
tqdm.pandas()


# @title define `filter_dataset`
def filter_dataset(
    data_dir: str,
    new_dir: str,
    min_word_count: int = 8,
    min_word_column: str = "output",
    drop_if_val_input: bool = False,
):
    # Convert directory paths to Path objects
    data_dir = Path(data_dir)
    assert (
        data_dir.exists() and data_dir.is_dir()
    ), f"{data_dir} is not a valid directory path."
    new_dir = Path(new_dir)
    new_dir.mkdir(parents=True, exist_ok=True)

    # Load the train/test/validation splits
    train_data = pd.read_csv(data_dir / "train.csv")
    test_data = pd.read_csv(data_dir / "test.csv")
    validation_data = pd.read_csv(data_dir / "validation.csv")

    # Add a <col>_word_count column holding the word-token count of each column
    def count_words(df):
        for col in list(df.columns):  # snapshot the columns: the loop adds new ones
            df[col + "_word_count"] = df[col].progress_apply(
                lambda x: len(word_tokenize(str(x)))
            )
        return df

    train_data = count_words(train_data)
    test_data = count_words(test_data)
    validation_data = count_words(validation_data)

    # Filter each split on the word count of min_word_column
    assert (
        min_word_column in train_data.columns
    ), f"{min_word_column} not found in train data columns."
    assert (
        min_word_column in test_data.columns
    ), f"{min_word_column} not found in test data columns."
    assert (
        min_word_column in validation_data.columns
    ), f"{min_word_column} not found in validation data columns."
    # .copy() so the in-place drops below don't warn about operating on a slice
    train_data = train_data[
        train_data[min_word_column + "_word_count"] >= min_word_count
    ].copy()
    test_data = test_data[
        test_data[min_word_column + "_word_count"] >= min_word_count
    ].copy()
    validation_data = validation_data[
        validation_data[min_word_column + "_word_count"] >= min_word_count
    ].copy()

    # Drop the helper word count columns
    train_data.drop(
        columns=[col for col in train_data.columns if col.endswith("_word_count")],
        inplace=True,
    )
    test_data.drop(
        columns=[col for col in test_data.columns if col.endswith("_word_count")],
        inplace=True,
    )
    validation_data.drop(
        columns=[
            col for col in validation_data.columns if col.endswith("_word_count")
        ],
        inplace=True,
    )

    # Optionally keep only rows with a NaN 'input' value, then drop the column
    if drop_if_val_input:
        print("dropping rows with non-NaN values in the 'input' column")
        train_data = train_data[train_data["input"].isna()].copy()
        test_data = test_data[test_data["input"].isna()].copy()
        validation_data = validation_data[validation_data["input"].isna()].copy()
        del train_data["input"]
        del test_data["input"]
        del validation_data["input"]

    # Save the filtered splits to the new directory
    train_data.to_csv(new_dir / "train.csv", index=False)
    test_data.to_csv(new_dir / "test.csv", index=False)
    validation_data.to_csv(new_dir / "validation.csv", index=False)

    # Return the describe() summary of each filtered split, keyed by split name
    return {
        "train": train_data.describe(),
        "test": test_data.describe(),
        "validation": validation_data.describe(),
    }
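
For context, a minimal usage sketch follows. The directory paths are hypothetical; the only assumption is the layout the function itself expects, i.e. a data_dir containing train.csv, test.csv, and validation.csv with an "output" column (and an "input" column if drop_if_val_input is used).

# hypothetical paths; any directory with train/test/validation CSVs works
stats = filter_dataset(
    data_dir="data/raw",
    new_dir="data/filtered",
    min_word_count=8,          # keep rows with >= 8 word tokens in 'output'
    min_word_column="output",
    drop_if_val_input=True,    # also drop rows whose 'input' is non-NaN
)
print(stats["train"])  # describe() summary of the filtered train split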