Created
December 2, 2021 20:29
-
-
Save lewtun/402ca5608ad90b07e9fcb4a811b06d41 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from datasets import load_dataset | |
def validate_datasets(reference_dataset, new_dataset):
    """Validate the column names and rows of the new dataset.

    Each split of ``reference_dataset`` is compared with the split of the
    same name in ``new_dataset``. Discrepancies are printed rather than
    raised so that every split is checked in a single pass; the success
    message is printed only when no discrepancy was found.

    Args:
        reference_dataset: Mapping from split name to an object exposing
            ``column_names`` and ``num_rows``.
        new_dataset: Mapping of the same shape to check against the reference.
    """
    all_match = True
    for split in reference_dataset:
        # Report a missing split instead of crashing with a KeyError.
        if split not in new_dataset:
            print(f"Split {split} is missing from the new dataset!")
            all_match = False
            continue

        ref_dset = reference_dataset[split]
        new_dset = new_dataset[split]

        # Check column names agree. The symmetric difference is used so that
        # columns missing from either side are reported, not only columns
        # absent from the new dataset.
        ref_cols = set(ref_dset.column_names)
        new_cols = set(new_dset.column_names)
        mismatched_cols = ref_cols.symmetric_difference(new_cols)
        if mismatched_cols:
            all_match = False
            print(
                f"Column names for split {split} do not agree! Mismatched columns: {mismatched_cols}"
            )

        # Check number of rows agree
        ref_rows = ref_dset.num_rows
        new_rows = new_dset.num_rows
        if ref_rows != new_rows:
            all_match = False
            print(
                f"Number of rows for split {split} do not agree! Reference dataset has {ref_rows} rows, "
                f"new dataset has {new_rows} rows"
            )

    # Only claim success when nothing above was reported.
    if all_match:
        print("Column names and number of rows match!")
def main():
    """Compare each listed "gem" config with its "GEM/mlsum" counterpart.

    For every (reference config, new config) pair, both datasets are loaded
    with ``load_dataset`` and handed to ``validate_datasets``.
    """
    # (config name under "gem", config name under "GEM/mlsum")
    config_pairs = (
        ("mlsum_de", "de"),
        ("mlsum_es", "es"),
    )
    for gem_config, mlsum_config in config_pairs:
        print(f"Validating GEM config {gem_config}...")
        reference = load_dataset("gem", name=gem_config)
        candidate = load_dataset("GEM/mlsum", name=mlsum_config)
        validate_datasets(reference, candidate)
# Run the validation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment