Skip to content

Instantly share code, notes, and snippets.

@codesankalp
Last active November 11, 2022 05:44
Show Gist options
  • Save codesankalp/dbb63c8b13dca8641a1dcaedcfe84ed5 to your computer and use it in GitHub Desktop.
import argparse
import json
import os
import subprocess
import sys
import traceback
from importlib import import_module
from typing import Dict, Final, List
# Keys of the metric.json column specs used throughout this script.
EXPECTED_COLUMN: Final = "expected_mapped_column"
SUBMISSION_COLUMN: Final = "submission_mapped_column"
NAME: Final = "name"
ID: Final = "id"

# Importable module name -> pinned pip requirement spec.
THIRD_PARTY_PACKAGES: Final[Dict[str, str]] = {
    "pandas": "pandas==1.4.2",
    "sklearn": "scikit-learn==1.1.1",
}

try:
    # Probe each dependency; a plain loop (not a throwaway list
    # comprehension) makes the import side effect explicit.
    for package in THIRD_PARTY_PACKAGES:
        import_module(package)
except ModuleNotFoundError:
    # At least one package is missing: install all pinned requirements.
    subprocess.check_call(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            *THIRD_PARTY_PACKAGES.values(),
        ]
    )
finally:
    # Imported in ``finally`` so the names are bound whether or not an
    # install was needed.
    import pandas as pd
    from sklearn.metrics import f1_score
class MetricException(Exception):
    """Base class for metric errors, rendered as a banner with a message.

    Subclasses provide ``message`` (as a class attribute or set in
    ``__init__``); ``__str__`` wraps it in a hash banner for CLI output.
    """

    # Default so str() never raises AttributeError on a bare instance
    # (the original base class defined no ``message`` at all).
    message: str = ""

    def __str__(self) -> str:
        return f"""
########################################
### {self.__class__.__name__} ###
{self.message}
########################################
"""
class InvalidSubmissionFile(MetricException):
    """Raised when the submission CSV does not contain the expected rows."""

    # Fix: the two concatenated fragments previously rendered as
    # "not valid.i.e." — a separating space was missing.
    message = (
        "This error is raised when the submission file is not valid. "
        "i.e. the submission file does not contain the expected rows."
    )
class ColumnsNotPresent(MetricException):
    """Raised when mapped columns are absent from the CSV file."""

    def __init__(self, columns: List[str]) -> None:
        self.message = "Columns not present in CSV file: {}".format(columns)
class InvalidMapping(MetricException):
    """Raised when two names in the mapping target the same CSV column."""

    def __init__(self, mapping: Dict) -> None:
        self.message = "Duplicate column names in mapping: {}".format(mapping)
def get_args(description: str) -> argparse.Namespace:
    """Parse the command line: metric definition path plus the two CSVs.

    Args:
        description: Text shown in the parser's ``--help`` output.

    Returns:
        Namespace with ``metric``, ``expected`` and ``submission`` paths.
    """
    arg_parser = argparse.ArgumentParser(description=description)
    # metric.json defaults to sitting next to this script.
    here = os.path.dirname(os.path.abspath(__file__))
    arg_parser.add_argument(
        "--metric",
        type=str,
        default=os.path.join(here, "metric.json"),
        help="Path to metric.json",
    )
    arg_parser.add_argument(
        "--expected", type=str, help="Path to expected csv file", required=True
    )
    arg_parser.add_argument(
        "--submission", type=str, help="Path to submission csv file", required=True
    )
    return arg_parser.parse_args()
def get_metric(metric_file_path: str) -> Dict:
    """Load and return the metric definition from a JSON file.

    Args:
        metric_file_path: Path to metric.json.

    Returns:
        The parsed JSON object.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # ``with`` guarantees the handle is closed; the original leaked it.
    with open(metric_file_path, "r") as metric_json:
        return json.load(metric_json)
def get_column_mapping(columns: List[Dict], column_name: str) -> Dict:
    """Build a name -> mapped-column lookup from the metric column specs.

    Args:
        columns: Column definition dicts from metric.json; each is
            expected to carry ``name`` and the requested ``column_name``
            key (missing keys yield ``None`` values via ``.get``).
        column_name: Which mapped-column key to extract
            (``EXPECTED_COLUMN`` or ``SUBMISSION_COLUMN``).

    Returns:
        Dict of lower-cased column name to its mapped CSV column.
        (The original annotated the return as ``List[Dict]``, which was
        wrong — a single dict is returned.)
    """
    mapping: Dict = {}
    for column in columns:
        # Lower-case the key so later lookups are case-insensitive.
        mapping[column.get(NAME).lower()] = column.get(column_name)
    return mapping
def validate_columns(df: pd.DataFrame, mapping: Dict) -> None:
    """Validate that the mapping is injective and its targets exist in df.

    Args:
        df: DataFrame loaded from the CSV file.
        mapping: Lower-cased canonical name -> CSV column name.

    Raises:
        InvalidMapping: If two names map to the same CSV column.
        ColumnsNotPresent: If any mapped column is missing from ``df``.
    """
    if len(set(mapping.values())) != len(set(mapping.keys())):
        raise InvalidMapping(mapping)
    # Report only the columns that are actually absent; the original
    # passed every mapped column, obscuring which ones were missing.
    missing = [column for column in mapping.values() if column not in df.columns]
    if missing:
        raise ColumnsNotPresent(missing)
def get_csv_data(file_path: str, mapping: Dict) -> pd.DataFrame:
    """Read a CSV, normalise its column names, and sort rows by id.

    Args:
        file_path: Path to the CSV file.
        mapping: Lower-cased canonical name -> CSV column name.

    Returns:
        DataFrame with canonical column names, de-duplicated on ``id``
        and sorted by it.
    """
    frame = pd.read_csv(file_path)
    validate_columns(frame, mapping)
    # Invert the mapping to rename CSV columns back to canonical names.
    canonical = {csv_col: name for name, csv_col in mapping.items()}
    frame = frame.rename(columns=canonical)
    frame = frame.drop_duplicates(subset=[ID])
    return frame.sort_values(by=ID)
def check_valid_submission_file(
    expected: pd.DataFrame, submission: pd.DataFrame
) -> None:
    """Ensure every submission id also appears in the expected file.

    Args:
        expected: Ground-truth DataFrame (must have an ``id`` column).
        submission: Submitted DataFrame (must have an ``id`` column).

    Raises:
        InvalidSubmissionFile: If the submission contains unknown ids.
    """
    # Iterate the Series directly. The original wrote ``.iloc()``, which
    # *calls* the positional indexer and yields an indexer object rather
    # than the column's values — fragile and unintended.
    submission_id_set = set(submission[ID])
    expected_id_set = set(expected[ID])
    if not submission_id_set.issubset(expected_id_set):
        raise InvalidSubmissionFile
def get_score(
    expected_file_path: str, submission_file_path: str, metric: Dict
) -> float:
    """Compute the F1 score of a submission CSV against the expected CSV.

    Args:
        expected_file_path: Path to the ground-truth CSV.
        submission_file_path: Path to the submitted CSV.
        metric: Parsed metric.json; its ``columns`` entries define the
            per-file column mappings.

    Returns:
        The F1 score in [0.0, 1.0].

    Raises:
        InvalidSubmissionFile: If the submission rows do not line up
            with the expected rows.
    """
    columns = metric.get("columns", [])
    expected_column_mapping = get_column_mapping(columns, EXPECTED_COLUMN)
    submission_column_mapping = get_column_mapping(columns, SUBMISSION_COLUMN)
    expected_df = get_csv_data(expected_file_path, expected_column_mapping)
    submission_df = get_csv_data(
        submission_file_path,
        submission_column_mapping,
    )
    check_valid_submission_file(expected_df, submission_df)
    # Left merge keeps every expected row; pandas suffixes the shared
    # "expected" column as expected_x / expected_y.
    merged_df = pd.merge(expected_df, submission_df, on=ID, how="left")
    try:
        # Pass keywords directly instead of the original ``**dict`` detour.
        return f1_score(y_true=merged_df["expected_x"], y_pred=merged_df["expected_y"])
    except ValueError as err:
        print(str(err))
        # No data or partial data in the submission file means the
        # submission is invalid; chain the cause for debuggability
        # (the original dropped it).
        raise InvalidSubmissionFile from err
if __name__ == "__main__":
    # Entry point: parse args, score the submission, and always print a
    # final FS_SCORE line. In the flattened source the trailing
    # ``print("FS_SCORE: 0%")`` ran unconditionally — after a successful
    # run it printed both the real score and 0%. Restructured so 0% is
    # printed only on failure, on both exception paths.
    try:
        args = get_args(description="Mean F-Score")
        metric = get_metric(args.metric)
        score = get_score(args.expected, args.submission, metric)
        print(f"FS_SCORE: {score*100}%")
    except MetricException as err:
        # Known, user-facing failure: show the banner message.
        print(str(err))
        print("FS_SCORE: 0%")
    except Exception:
        # Unexpected failure: full traceback for debugging, score 0.
        traceback.print_exc()
        print("FS_SCORE: 0%")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment