Created April 29, 2024 12:14
-
-
Save NirantK/8c2e763ced9163526b8af0d5f43a2f3e to your computer and use it in GitHub Desktop.
Validation function for TREC run predictions and qrels references. Helpful when working with TREC tools.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Expected column schema for system predictions (a TREC-style run): each
# field maps to a sequence of values of the given element type.
_EXPECTED_PRED_KEYS = {
    'query': int, 'q0': str, 'docid': str, 'rank': int, 'score': float, 'system': str,
}
# Expected column schema for references (qrels-style relevance judgements).
_EXPECTED_REF_KEYS = {
    'query': int, 'q0': str, 'docid': str, 'rel': int,
}


def _is_instance(item, expected_type):
    """Return True if *item* is an *expected_type*, rejecting bools for non-bool types."""
    # bool subclasses int, so isinstance(True, int) is True; a bool is not a
    # meaningful query id, rank or relevance label.
    if isinstance(item, bool) and expected_type is not bool:
        return False
    return isinstance(item, expected_type)


def _is_sequence(values):
    """Return True if *values* is a sized, iterable container that is not str/bytes."""
    # A bare string would otherwise be iterated character by character and
    # silently pass a `str` element check.
    if isinstance(values, (str, bytes)):
        return False
    return hasattr(values, '__len__') and hasattr(values, '__iter__')


def _check_record(record, expected_keys):
    """Validate one column-oriented record against *expected_keys*.

    Args:
        record: Mapping of field name -> sequence of values.
        expected_keys: Mapping of required field name -> expected element type.

    Returns:
        "Valid", or a message describing the first problem found. Fields that
        are not in *expected_keys* are ignored.
    """
    for key, expected_type in expected_keys.items():
        if key not in record:
            return f"Missing key: {key}"
        values = record[key]
        if not _is_sequence(values):
            return (f"Incorrect type for key {key}, expected a sequence of "
                    f"{expected_type}, got {type(values)}")
        for item in values:
            if not _is_instance(item, expected_type):
                # Report the offending element's type, not the first element's.
                return f"Incorrect type for key {key}, expected {expected_type}, got {type(item)}"
    # Every schema column must describe the same number of rows.
    lengths = {len(record[key]) for key in expected_keys}
    if len(lengths) > 1:
        return "Inconsistent lengths among fields"
    return "Valid"


def validate_data(predictions, references):
    """Validate column-oriented predictions (run) and references (qrels).

    Args:
        predictions: Mapping with sequence-valued fields 'query' (int),
            'q0' (str), 'docid' (str), 'rank' (int), 'score' (float) and
            'system' (str), all of equal length.
        references: Mapping with sequence-valued fields 'query' (int),
            'q0' (str), 'docid' (str) and 'rel' (int), all of equal length.

    Returns:
        A tuple ``(pred_validation, ref_validation)``. Each element is
        "Valid" or a message describing the first problem found in that input.
    """
    pred_validation = _check_record(predictions, _EXPECTED_PRED_KEYS)
    ref_validation = _check_record(references, _EXPECTED_REF_KEYS)
    return pred_validation, ref_validation
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.