@pszemraj
Created March 20, 2023 00:47
The `filter_dataset` function reads the train/test/validation CSV splits from a directory, counts the tokens in every column, drops rows whose `min_word_column` has fewer than `min_word_count` tokens, optionally drops rows with a non-NaN value in the `input` column, and saves the filtered splits to a new directory. It returns summary statistics for each filtered split.
from pathlib import Path

import nltk
import pandas as pd
from nltk.tokenize import word_tokenize
from tqdm.auto import tqdm

nltk.download("punkt", quiet=True)  # word_tokenize needs the punkt tokenizer models
tqdm.pandas()
#@title define `filter_dataset`
def filter_dataset(
    data_dir: str,
    new_dir: str,
    min_word_count: int = 8,
    min_word_column: str = "output",
    drop_if_val_input: bool = False,
):
    """Filter the train/test/validation CSV splits in data_dir by token count.

    Rows whose min_word_column contains fewer than min_word_count tokens are dropped;
    if drop_if_val_input is True, rows with a non-NaN 'input' value are also dropped
    and the 'input' column is removed. The filtered splits are written to new_dir and
    summary statistics for each split are returned.
    """
    # Convert directory paths to Path objects
    data_dir = Path(data_dir)
    assert (
        data_dir.exists() and data_dir.is_dir()
    ), f"{data_dir} is not a valid directory path."
    new_dir = Path(new_dir)
    new_dir.mkdir(parents=True, exist_ok=True)

    # Load the datasets
    train_data = pd.read_csv(data_dir / "train.csv")
    test_data = pd.read_csv(data_dir / "test.csv")
    validation_data = pd.read_csv(data_dir / "validation.csv")

    # Define a helper to count the number of tokens in each column
    def count_words(df):
        # snapshot the column names so the new *_word_count columns are not re-counted
        for col in list(df.columns):
            df[col + "_word_count"] = df[col].progress_apply(
                lambda x: len(word_tokenize(str(x)))
            )
        return df

    # Count the tokens in each dataset
    train_data = count_words(train_data)
    test_data = count_words(test_data)
    validation_data = count_words(validation_data)

    # Filter the dataframes by min_word_column
    assert (
        min_word_column in train_data.columns
    ), f"{min_word_column} not found in train data columns."
    assert (
        min_word_column in test_data.columns
    ), f"{min_word_column} not found in test data columns."
    assert (
        min_word_column in validation_data.columns
    ), f"{min_word_column} not found in validation data columns."
    train_data = train_data[
        train_data[min_word_column + "_word_count"] >= min_word_count
    ]
    test_data = test_data[test_data[min_word_column + "_word_count"] >= min_word_count]
    validation_data = validation_data[
        validation_data[min_word_column + "_word_count"] >= min_word_count
    ]

    # Drop the token count columns (reassignment avoids SettingWithCopyWarning)
    train_data = train_data.drop(
        columns=[col for col in train_data.columns if col.endswith("_word_count")]
    )
    test_data = test_data.drop(
        columns=[col for col in test_data.columns if col.endswith("_word_count")]
    )
    validation_data = validation_data.drop(
        columns=[col for col in validation_data.columns if col.endswith("_word_count")]
    )

    # Keep only rows with a NaN 'input' value if drop_if_val_input is True,
    # then remove the now-empty 'input' column
    if drop_if_val_input:
        print("dropping rows with non-NA values in the 'input' column")
        train_data = train_data[train_data["input"].isna()]
        test_data = test_data[test_data["input"].isna()]
        validation_data = validation_data[validation_data["input"].isna()]
        del train_data["input"]
        del test_data["input"]
        del validation_data["input"]

    # Save the filtered datasets to the new directory
    train_data.to_csv(new_dir / "train.csv", index=False)
    test_data.to_csv(new_dir / "test.csv", index=False)
    validation_data.to_csv(new_dir / "validation.csv", index=False)

    # Return summary statistics of the filtered splits, keyed by split name
    return {
        "train": train_data.describe(),
        "test": test_data.describe(),
        "validation": validation_data.describe(),
    }
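

# Example usage: a minimal sketch, not part of the original gist. The directory
# paths, the "output" column, and the alpaca-style train/test/validation.csv
# layout are assumptions for illustration only.
if __name__ == "__main__":
    stats = filter_dataset(
        data_dir="data/alpaca",  # hypothetical directory containing the three CSV splits
        new_dir="data/alpaca-filtered",
        min_word_count=8,
        min_word_column="output",
        drop_if_val_input=True,  # keep only rows without an 'input' value
    )
    for split, summary in stats.items():
        print(f"=== {split} ===")
        print(summary)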