Created
October 18, 2022 04:36
-
-
Save codesankalp/0204b2e3a83696d5d66a604b7af83d91 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import json | |
import os | |
import subprocess | |
import sys | |
from importlib import import_module | |
from typing import Dict, Final, List | |
# Keys looked up in each column entry of the metric definition.
EXPECTED_COLUMN: Final = "expected_mapped_column"
SUBMISSION_COLUMN: Final = "submission_mapped_column"
NAME: Final = "name"
# Canonical column name used to sort and join the expected/submission data.
ID: Final = "id"

# Importable module name -> pinned pip requirement installed when missing.
THIRD_PARTY_PACKAGES: Dict[str, str] = {
    "pandas": "pandas==1.4.2",
    "sklearn": "scikit-learn==1.1.1",
}

try:
    # Probe every dependency; a plain loop because only the side effect
    # (raising when a module is missing) matters.
    for _module_name in THIRD_PARTY_PACKAGES:
        import_module(_module_name)
except ModuleNotFoundError:
    from importlib import invalidate_caches

    # Install the pinned versions with the interpreter running this script.
    subprocess.check_call(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            *THIRD_PARTY_PACKAGES.values(),
        ]
    )
    # Modules installed while the interpreter is running may be missed by the
    # import system's finder caches unless those caches are reset.
    invalidate_caches()

# Imported after the bootstrap rather than inside a ``finally`` clause, so a
# failed install surfaces its own error instead of being replaced by a
# ModuleNotFoundError raised from these imports.
import pandas as pd
from sklearn.metrics import balanced_accuracy_score
class InvalidSubmissionFile(Exception):
    """Signals that a submission file could not be scored."""
def get_args(description: str) -> argparse.Namespace:
    """Parse the command-line options of the scoring script.

    Args:
        description: Program description handed to ``ArgumentParser``.

    Returns:
        Namespace with ``metric``, ``expected`` and ``submission`` attributes.
        ``metric`` defaults to a ``metric.json`` located in the same directory
        as this file; ``expected`` and ``submission`` are required.
    """
    default_metric_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "metric.json",
    )
    # (flag, add_argument keyword options), registered in this order.
    option_table = (
        (
            "--metric",
            {
                "type": str,
                "default": default_metric_path,
                "help": "Path to metric.json",
            },
        ),
        (
            "--expected",
            {
                "type": str,
                "help": "Path to expected csv file",
                "required": True,
            },
        ),
        (
            "--submission",
            {
                "type": str,
                "help": "Path to submission csv file",
                "required": True,
            },
        ),
    )
    cli = argparse.ArgumentParser(description=description)
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def get_metric(metric_file_path: str) -> Dict:
    """Load the metric definition from a JSON file.

    Args:
        metric_file_path: Path to the JSON file to read.

    Returns:
        The decoded JSON content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # The context manager closes the handle even when decoding fails; the
    # previous version left the file open. JSON text is read as UTF-8 instead
    # of relying on the platform's default encoding.
    with open(metric_file_path, "r", encoding="utf-8") as metric_json:
        return json.load(metric_json)
def get_column_mapping(columns: List[Dict], column_name: str) -> Dict:
    """Build a lookup from lower-cased column names to one field of each entry.

    Args:
        columns: Column entries; each is read with ``.get`` for the ``NAME``
            key and for ``column_name``.
        column_name: Key whose value becomes the mapping value.

    Returns:
        Dict mapping each entry's lower-cased name to its ``column_name``
        value (``None`` when the entry lacks that key). When two entries
        share a name, the later one wins.

    Raises:
        AttributeError: If an entry has no ``NAME`` key, because ``None``
            has no ``lower`` method.
    """
    # The return annotation used to say List[Dict] although a dict is built.
    return {
        column.get(NAME).lower(): column.get(column_name)
        for column in columns
    }
def get_csv_data(file_path: str, mapping: Dict) -> pd.DataFrame:
    """Read a CSV file, rename its columns via ``mapping`` and sort by ``ID``.

    Args:
        file_path: Path of the CSV file to read.
        mapping: Target column name -> column name as it appears in the CSV.

    Returns:
        The data with mapped columns renamed to their target names, ordered
        by the ``ID`` column.
    """
    # ``mapping`` points from target name to CSV header; renaming needs the
    # opposite direction.
    header_to_target = dict(zip(mapping.values(), mapping.keys()))
    frame = pd.read_csv(file_path)
    renamed = frame.rename(columns=header_to_target)
    return renamed.sort_values(by=ID)
def get_score(
    expected_file_path: str, submission_file_path: str, metric: Dict
) -> float:
    """Compute the balanced accuracy of a submission against the expected data.

    Args:
        expected_file_path: Path to the expected CSV file.
        submission_file_path: Path to the submission CSV file.
        metric: Metric definition; its ``"columns"`` list supplies the column
            mappings for both files.

    Returns:
        The value returned by ``balanced_accuracy_score``.

    Raises:
        InvalidSubmissionFile: If ``balanced_accuracy_score`` raises
            ``ValueError``; the original error is attached as the cause.
    """
    columns = metric.get("columns", [])
    expected_df = get_csv_data(
        expected_file_path,
        get_column_mapping(columns, EXPECTED_COLUMN),
    )
    submission_df = get_csv_data(
        submission_file_path,
        get_column_mapping(columns, SUBMISSION_COLUMN),
    )
    # Left join keeps every expected row. Columns present in both frames get
    # pandas' default "_x" (expected side) / "_y" (submission side) suffixes.
    # NOTE(review): this relies on both frames having a column named
    # "expected" after renaming -- confirm against metric.json.
    merged_df = pd.merge(expected_df, submission_df, on=ID, how="left")
    try:
        return balanced_accuracy_score(
            y_true=merged_df["expected_x"],
            y_pred=merged_df["expected_y"],
        )
    except ValueError as err:
        print(str(err))
        # if there is no data or partial data in the submission file
        # it means the submission file is not valid
        raise InvalidSubmissionFile("Invalid submission file") from err
if __name__ == "__main__":
    # Reported as-is when anything below fails.
    score = 0
    try:
        args = get_args(description="Balanced Accuracy")
        score = get_score(
            args.expected,
            args.submission,
            get_metric(args.metric),
        )
    except Exception as err:
        print(str(err))
    finally:
        # Kept in ``finally`` so the score line is printed on every exit
        # path, including exceptions not caught above.
        print(f"FS_SCORE: {score*100}%")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment