Last active
July 18, 2017 13:00
-
-
Save saraswatmks/3f898d4ab04acd6770283ba52d7f2abd to your computer and use it in GitHub Desktop.
Normalized discounted cumulative gain (NDCG) checker
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
# NDCG checker: scores a submission CSV of ten ranked events per PID against
# a testcase CSV and prints the mean NDCG.
import csv
import json
import sys

import numpy as np

# When False, both file paths are read from a single JSON object on one line
# of stdin (keys 'input_testcase_file_path' and 'submission_file_path').
LOCAL = True

if not LOCAL:
    try:
        _read_line = raw_input  # Python 2
    except NameError:
        _read_line = input  # Python 3: raw_input was renamed to input
    data = json.loads(_read_line())
    testcase = data['input_testcase_file_path']
    user_submission = data['submission_file_path']
else:
    # Placeholders for local runs; replace with real file paths.
    testcase = 'path'
    user_submission = 'path'

# Ideal DCG for ten ranked events: every 1-based rank p has relevance 1, so its
# gain is (2**1 - 1) / log2(p + 1) = 1 / log2(p + 1) for p = 1..10.
IDCG_SUM = sum(1 / np.log2(x) for x in range(2, 12))
# Column layout shared by the submission and testcase files: a PID column
# followed by ten ranked event columns.
_EVENT_HEADERS = tuple('Event' + str(rank) for rank in range(1, 11))
_HEADERS = ('PID',) + _EVENT_HEADERS


def _load_submission(fp_submission):
    """Parse the user's submission into {PID: [Event1, ..., Event10]}.

    Raises:
        Exception: with a user-facing message when a required column is
            missing or a data row cannot be parsed as CSV.
    """
    # The sniffed dialect (delimiter restricted to ',') is used to parse the
    # header row only; csv.Sniffer raises csv.Error if it cannot decide.
    dialect = csv.Sniffer().sniff(fp_submission.read(1024), delimiters=',')
    fp_submission.seek(0)
    # An empty file yields no header row; [] makes the check below fail cleanly.
    submission_headers = next(csv.reader(fp_submission, dialect), [])
    for header in _HEADERS:
        if header not in submission_headers:
            raise Exception('File does not contain correct headers: {}'.format(
                ', '.join(_HEADERS)))
    # Key rows by the submission's OWN header row rather than the canonical
    # order: the check above accepts any column order, so a positional mapping
    # would silently mis-score a reordered file.
    row_reader = csv.DictReader(fp_submission, fieldnames=submission_headers,
                                delimiter=',')
    predictions = {}
    try:
        for row in row_reader:
            predictions[row['PID']] = [row[name] for name in _EVENT_HEADERS]
    except csv.Error as err:
        raise Exception('Could not parse submission file: {0}'.format(err))
    return predictions


def _load_testcase(fp_testcase):
    """Parse the ground-truth file into {PID: [Event1, ..., Event10]}.

    Columns are mapped positionally onto _HEADERS. If the file has a header
    row, it is therefore read as data with PID == 'PID'. It is dropped here so
    that it is neither scored nor counted in the mean.
    """
    actual = {}
    for row in csv.DictReader(fp_testcase, fieldnames=_HEADERS):
        if row['PID'] == 'PID':
            continue
        actual[row['PID']] = [row[name] for name in _EVENT_HEADERS]
    return actual


def _ndcg(predicted_events, actual_events):
    """Return the NDCG of one ranked prediction against the actual ranking.

    Each predicted event at 1-based rank p is matched to its next unused
    occurrence in actual_events, so a repeated event consumes successive
    occurrences. If that occurrence is at 1-based rank a, its relevance is
    min(a/p, p/a), which is 1.0 for an exact position match, and its gain is
    (2**rel - 1) / log2(p + 1). Unmatched events contribute 0. The DCG is
    divided by IDCG_SUM.
    """
    # NOTE(review): empty cells ('' or None for short rows) compare equal, so
    # blank predictions can earn credit against blank ground-truth cells --
    # confirm that is intended.
    last_match = {}  # event -> index of its most recently consumed occurrence
    dcg = 0.0
    for pred_rank, event in enumerate(predicted_events, 1):
        try:
            match = actual_events.index(event, last_match.get(event, -1) + 1)
        except ValueError:
            continue  # no remaining occurrence: contributes 0
        last_match[event] = match
        actual_rank = match + 1
        rel = min(float(actual_rank) / pred_rank,
                  float(pred_rank) / actual_rank)
        dcg += (2 ** rel - 1) / np.log2(pred_rank + 1)
    return dcg / IDCG_SUM


def verify_submission():
    """Score the submission against the testcase and print the mean NDCG.

    Reads the module-level paths ``user_submission`` and ``testcase`` and
    prints the mean per-PID NDCG, formatted with five decimals.

    Raises:
        Exception: with a user-facing message when the submission is not a
            .csv file, lacks a required column, cannot be parsed, has no
            prediction for a testcase PID, or the testcase has no rows.
            Errors from open() or csv.Sniffer propagate unchanged.
    """
    if not user_submission.endswith('.csv'):
        raise Exception('Please upload a csv containing predictions')
    with open(user_submission, 'r') as fp_submission, \
            open(testcase, 'r') as fp_testcase:
        submission_dict = _load_submission(fp_submission)
        testcase_dict = _load_testcase(fp_testcase)
    if not testcase_dict:
        raise Exception('Testcase file does not contain any rows')
    global_ndcg = 0.0
    for pid, actual_events in testcase_dict.items():
        if pid not in submission_dict:
            raise Exception('File does not contain prediction for {0}'.format(pid))
        global_ndcg += _ndcg(submission_dict[pid], actual_events)
    # The header row was dropped while loading, so len() counts only the PIDs
    # that were actually scored.
    mean = global_ndcg / len(testcase_dict)
    print('{0:.5f}'.format(mean))
try:
    verify_submission()
except Exception as e:
    # Report the failure as the process exit message (sys.exit prints a
    # non-integer argument to stderr and exits with status 1). str(e) is used
    # because e.message does not exist on Python 3 and is empty on Python 2
    # for multi-argument exceptions such as IOError(errno, strerror).
    sys.exit(str(e))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment