Skip to content

Instantly share code, notes, and snippets.

@lewtun
Created December 2, 2021 20:29
Show Gist options
  • Save lewtun/402ca5608ad90b07e9fcb4a811b06d41 to your computer and use it in GitHub Desktop.
Save lewtun/402ca5608ad90b07e9fcb4a811b06d41 to your computer and use it in GitHub Desktop.
from datasets import load_dataset
def validate_datasets(reference_dataset, new_dataset):
    """Validate the column names and rows of the new dataset.

    For every split of ``reference_dataset`` the same split is looked up in
    ``new_dataset`` and compared on two points:

    * every column name of the reference split must exist in the new split
      (extra columns that only exist in the new split are not reported);
    * both splits must have the same ``num_rows``.

    Each disagreement is printed. The success message is printed only when
    no disagreement was found in any split.

    Args:
        reference_dataset: Mapping from split name to an object exposing
            ``column_names`` and ``num_rows``.
        new_dataset: Mapping indexed by the same split names, with values
            exposing the same two attributes.

    Returns:
        ``True`` if all splits agree, ``False`` otherwise.

    Raises:
        KeyError: If ``new_dataset`` is a dict-like mapping that lacks one of
            the reference splits.
    """
    all_match = True
    for split, ref_dset in reference_dataset.items():
        new_dset = new_dataset[split]

        # Check column names agree
        mismatched_cols = set(ref_dset.column_names).difference(new_dset.column_names)
        if mismatched_cols:
            all_match = False
            print(
                f"Column names for split {split} do not agree! Mismatched columns: {mismatched_cols}"
            )

        # Check number of rows agree
        ref_rows = ref_dset.num_rows
        new_rows = new_dset.num_rows
        if ref_rows != new_rows:
            all_match = False
            print(
                f"Number of rows for split {split} do not agree! Reference dataset has {ref_rows} rows, "
                f"new dataset has {new_rows} rows"
            )

    # Only claim success when nothing above was reported.
    if all_match:
        print("Column names and number of rows match!")
    return all_match
def main():
    """Compare each MLSum config of the ``gem`` dataset with ``GEM/mlsum``.

    For the German and Spanish configs, both datasets are loaded with
    ``load_dataset`` and handed to ``validate_datasets``.
    """
    # (config name in "gem", config name in "GEM/mlsum")
    config_pairs = (
        ("mlsum_de", "de"),
        ("mlsum_es", "es"),
    )
    for gem_config, mlsum_config in config_pairs:
        print(f"Validating GEM config {gem_config}...")
        reference = load_dataset("gem", name=gem_config)
        candidate = load_dataset("GEM/mlsum", name=mlsum_config)
        validate_datasets(reference, candidate)
# Run the validation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment