Skip to content

Instantly share code, notes, and snippets.

@saraswatmks
Last active July 18, 2017 13:00
Show Gist options
  • Save saraswatmks/3f898d4ab04acd6770283ba52d7f2abd to your computer and use it in GitHub Desktop.
Save saraswatmks/3f898d4ab04acd6770283ba52d7f2abd to your computer and use it in GitHub Desktop.
Normalized discounted cumulative gain (nDCG) checker
#!/usr/bin/env python
import csv
import json
import sys
import numpy as np
# When LOCAL is true the hard-coded placeholder paths below are used.
# Otherwise a single JSON line is read from stdin and must provide the keys
# 'input_testcase_file_path' and 'submission_file_path'.
LOCAL = True

if LOCAL:
    testcase = 'path'
    user_submission = 'path'
else:
    # sys.stdin.readline() rather than raw_input(): raw_input only exists on
    # Python 2, while this form works on both Python 2 and Python 3.
    data = json.loads(sys.stdin.readline())
    testcase = data['input_testcase_file_path']
    user_submission = data['submission_file_path']

# Ideal DCG used as the normaliser: ten ranked positions, each contributing
# the maximum gain (2**1 - 1) / log2(position + 1) for positions 1..10.
IDCG_SUM = sum(1 / np.log2(x) for x in range(2, 12))
def _score_prediction(predicted_events, actual_events):
    """Return the nDCG of one ranked prediction against the actual ranking.

    Each predicted event is matched to an occurrence of the same event in
    ``actual_events``. Its relevance is the ratio between the smaller and the
    larger of the two 1-based positions, so it is 1 when the positions agree
    and shrinks as they drift apart. Events that cannot be matched contribute
    a gain of 0. The summed DCG is normalised by the module-level IDCG_SUM.
    """
    partial_dcgs = []
    # Index (within actual_events) of the most recent match per event, so a
    # repeated prediction can only match a later occurrence of that event.
    last_match = {}
    for position, event in enumerate(predicted_events, 1):
        search_from = last_match.get(event, -1) + 1
        try:
            matched = actual_events.index(event, search_from)
        except ValueError:
            # Event absent, or every occurrence was already consumed.
            partial_dcgs.append(0)
            continue
        last_match[event] = matched
        actual_position = matched + 1
        rel = min(float(actual_position) / position,
                  float(position) / actual_position)
        partial_dcgs.append(float(2 ** rel - 1) / np.log2(position + 1))
    return sum(partial_dcgs) / IDCG_SUM


def verify_submission():
    """Score the submission CSV against the testcase CSV and print mean nDCG.

    Reads the module-level ``user_submission`` and ``testcase`` paths. Both
    files are expected to carry the columns PID, Event1 .. Event10. The mean
    nDCG over all testcase rows is printed with five decimals.

    Raises:
        Exception: if the submission is not a ``.csv`` file, lacks one of the
            required headers, cannot be parsed, or has no prediction for a
            PID present in the testcase file.
    """
    if not user_submission.endswith('.csv'):
        raise Exception('Please upload a csv containing predictions')

    event_columns = ['Event' + str(i) for i in range(1, 11)]
    headers = ['PID'] + event_columns

    submission_dict = {}
    testcase_dict = {}
    with open(user_submission, 'r') as fp_submission, \
            open(testcase, 'r') as fp_testcase:
        dialect = csv.Sniffer().sniff(fp_submission.read(1024), delimiters=',')
        fp_submission.seek(0)
        # Consume the header line; an empty file yields no headers at all.
        submission_headers = next(csv.reader(fp_submission, dialect), [])
        if any(header not in submission_headers for header in headers):
            raise Exception('File does not contain correct headers: {}'.format(', '.join(headers)))

        # Map columns by the names actually found in the file. The header
        # check above is order-agnostic, so parsing must be as well;
        # otherwise reordered or extra columns are silently misassigned.
        submission_reader = csv.DictReader(fp_submission,
                                           fieldnames=submission_headers,
                                           delimiter=',')
        # The testcase file is parsed positionally, so its own header line
        # shows up as a row whose PID is the literal string 'PID'.
        testcase_reader = csv.DictReader(fp_testcase, fieldnames=headers)

        try:
            for row in submission_reader:
                submission_dict[row['PID']] = [row[col] for col in event_columns]
        except (KeyError, csv.Error):
            raise Exception('File does not contain correct headers')
        for row in testcase_reader:
            testcase_dict[row['PID']] = [row[col] for col in event_columns]

    total_ndcg = 0
    scored_rows = 0
    for pid, actual_events in testcase_dict.items():
        if pid == 'PID':
            # Header line of the testcase file, not a real participant.
            continue
        try:
            predicted_events = submission_dict[pid]
        except KeyError:
            raise Exception('File does not contain prediction for {0}'.format(pid))
        total_ndcg += _score_prediction(predicted_events, actual_events)
        scored_rows += 1

    # Average over the rows actually scored: the skipped header row must not
    # inflate the denominator, and an empty answer key must not divide by 0.
    mean = total_ndcg / scored_rows if scored_rows else 0.0
    print('{0:.5f}'.format(mean))
# Script entry point: run the checker and turn any failure into a non-zero
# exit whose message is written to stderr by sys.exit().
try:
    verify_submission()
except Exception as e:
    # str(e) rather than e.message: the latter is empty for multi-argument
    # exceptions such as IOError and does not exist on Python 3, which would
    # hide the reason for the failure.
    sys.exit(str(e))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment