Skip to content

Instantly share code, notes, and snippets.

@noam1023
Created December 21, 2021 13:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save noam1023/c7750a104d97afb7e2c7aa1474b1f29b to your computer and use it in GitHub Desktop.
Save noam1023/c7750a104d97afb7e2c7aa1474b1f29b to your computer and use it in GitHub Desktop.
# check the results of DSLAB HW2
# expected input:
# check_hw2.py industry2cluster_123456789_987654312.csv company2cluster_123456789_987654312.csv golden.csv
import random
import sys
import csv
def read_industries(fn) -> dict:
    """Read the student's mapping industry -> cluster.

    :param fn: CSV file whose first line is a header, followed by rows of
               ``industry name, cluster ID`` (industry names may be quoted).
    :return: dict {industry name (str) -> cluster ID (int)}
    """
    r = {}
    # newline='' is what the csv module expects for correct handling of quoted newlines.
    with open(fn, 'r', newline='') as input_file:
        reader = csv.reader(input_file, quoting=csv.QUOTE_ALL)
        next(reader, None)  # skip header line (tolerates a completely empty file)
        for row in reader:
            if not row:  # blank line: unpacking [] would raise ValueError
                continue
            industry, cluster_id = row
            r[industry] = int(cluster_id.strip())
    return r
def read_golden_industry(fn) -> dict:
    """
    Read the golden CSV file mapping company ID -> true industry name.

    The file has no header line. Industry names may contain commas, so they are
    expected to be quoted.

    :param fn: CSV file name
    :return: dict {company ID (str, exactly as written in the file) -> industry name (str)}
    """
    r = {}
    with open(fn, 'r', newline='') as input_file:
        reader = csv.reader(input_file, quoting=csv.QUOTE_ALL)
        for row in reader:
            if not row:  # blank line: unpacking [] would raise ValueError
                continue
            company_id, industry = row
            r[company_id] = industry
    return r
def incr(aDict, key):
    """Increment the counter stored under *key* in *aDict*.

    A key seen for the first time gets the value 1, so the values are true
    occurrence counts. The original code initialised new keys to 0, which was an off-by-one.
    """
    aDict[key] = aDict.get(key, 0) + 1
def check(industry2cluster_fn, company2_cluster_fn, golden_company_industry_fn):
    """
    Score a student's clustering against the true industry of each company.

    A company counts as correct when the cluster the student assigned to it
    equals the cluster the student assigned to its true (golden) industry.

    :param industry2cluster_fn: CSV (with header) mapping industry name -> cluster ID
    :param company2_cluster_fn: CSV (with header) mapping company ID -> predicted cluster ID
    :param golden_company_industry_fn: CSV (no header) mapping company ID -> true industry
    :return: percentage of correct companies (0-100), multiplied by the cluster-size
             penalty factor from check_cluster_sizes(), floored at 0.7.
    :raises ValueError: if a predicted cluster ID is outside 0..20.
    :raises KeyError: if a true industry is missing from the industry2cluster file.
    """
    num_clusters = 20
    num_matching = 0
    total_data_points = 0
    golden_company_industry = read_golden_industry(golden_company_industry_fn)
    industry2cluster = read_industries(industry2cluster_fn)
    # One dict per cluster ID 0..num_clusters; slot 0 only keeps the indexing
    # 1-based. A comprehension is required: [{}] * 21 would alias one dict.
    pred_company2cluster = [dict() for _ in range(num_clusters + 1)]
    with open(company2_cluster_fn, 'r', newline='') as company_csv:
        reader = csv.reader(company_csv, quoting=csv.QUOTE_ALL)
        next(reader, None)  # skip header line
        for row in reader:
            if not row:  # blank line: unpacking [] would raise ValueError
                continue
            comp_id, pred_cluster_id = row
            pred_cluster_id = int(pred_cluster_id.strip())
            # A negative ID would silently index from the end of the list and a
            # too-large one would raise a bare IndexError; reject both clearly.
            if not 0 <= pred_cluster_id <= num_clusters:
                raise ValueError("company %s: cluster ID %d is outside 0..%d"
                                 % (comp_id, pred_cluster_id, num_clusters))
            total_data_points += 1
            try:
                true_industry = golden_company_industry[comp_id]
            except KeyError:
                # %s rather than %d(int(...)): a non-numeric ID must not crash the report.
                print("company ID %s not found" % comp_id)
                continue
            clusterID_from_industry = industry2cluster[true_industry]
            num_matching += clusterID_from_industry == pred_cluster_id
            incr(pred_company2cluster[pred_cluster_id], true_industry)
            # Indexing by clusterID_from_industry instead would be circular
            # reasoning: true_industry would then decide both key and value.
    if total_data_points == 0:
        # A header-only company file would otherwise raise ZeroDivisionError.
        success_pct = 0.0
    else:
        success_pct = 100 * num_matching / total_data_points
    print("correctly identified: %d = %d%%" % (num_matching, int(success_pct)))
    # Drop slot 0, which only exists to keep the cluster IDs 1-based.
    factor = check_cluster_sizes(pred_company2cluster[1:])
    factor = max(factor, 0.7)  # don't be too cruel
    print("Penalty factor due to cluster sizes: %d%%" % (factor * 100))
    return success_pct * factor
def check_cluster_sizes(pred: list, min_industries: int = 4, max_industries: int = 15,
                        penalty_scale: float = 0.005) -> float:
    """Compute a penalty factor for clusters with too few or too many industries.

    Each cluster should contain between *min_industries* and *max_industries*
    distinct industries (defaults 4 and 15, roughly 3% and 10% of 147 industries).
    The cluster size (number of companies) is not what is checked here; the
    range applies to the number of distinct true industries in the cluster.

    :param pred: list of dicts, one per cluster, each mapping
                 industry -> count of companies that REALLY have that industry.
    :param min_industries: smallest acceptable number of industries per cluster.
    :param max_industries: largest acceptable number of industries per cluster.
    :param penalty_scale: factor reduction per industry outside the allowed range.
    :return: factor in [0.0, 1.0]; 1.0 means no cluster is out of range.
    """
    # For each cluster, how many distinct industries are represented in it.
    cluster_num_industries = [len(industries) for industries in pred]
    print("Cluster sizes:", cluster_num_industries)
    outliers = [x for x in cluster_num_industries
                if x < min_industries or x > max_industries]
    if outliers:
        print("outliers %", outliers)
    # Computed unconditionally so that no outliers gives a penalty of 0 rather
    # than referencing an unbound name.
    penalty = sum(min_industries - t if t < min_industries else t - max_industries
                  for t in outliers) * penalty_scale
    # The raw penalty is unbounded above; clamp to keep the documented [0, 1] range.
    return max(0.0, 1.0 - penalty)
def _gen_industry2cluster(golden_company_industry_fn, out_fn='industry2cluster.csv'):
    """Generate a fake result file mapping industry (str) -> cluster (int in [1..20]).

    A header line is written first because read_industries() skips the first
    line of its input. Without it, one industry would be silently lost and
    check() would later fail on it.

    :param golden_company_industry_fn: golden CSV (company ID -> industry).
    :param out_fn: path of the CSV file to write.
    """
    golden_company_industry = read_golden_industry(golden_company_industry_fn)
    industries = set(golden_company_industry.values())
    print("found %d industries" % len(industries))
    with open(out_fn, 'w', newline='') as fout:
        # QUOTE_NONNUMERIC quotes the names (escaping embedded quotes) but not the
        # cluster number, matching the original '"name",N' layout.
        writer = csv.writer(fout, quoting=csv.QUOTE_NONNUMERIC, lineterminator='\n')
        writer.writerow(['industry', 'cluster'])
        for name in industries:
            writer.writerow([name, random.randint(1, 20)])
def _gen_company2cluster(num_companies):
""" create fake table of company ID -> cluster ID"""
import random
offset = 1500000
with open('company2cluster.csv', 'w') as fout:
for i in range(num_companies):
fout.write('%d,%d\n' % (i + offset, random.randint(1, 20)))
def is_sane(row):
    """Return True if *row* looks like a well-formed labeled.csv record.

    labeled.csv contains a lot of garbage. A row is accepted only when it has
    exactly 8 fields and its first field parses as an int. The length is checked
    first, so an empty row returns False instead of raising IndexError.
    """
    if len(row) != 8:
        return False
    try:
        int(row[0])
    except ValueError:
        return False
    return True
def _gen_golden_company_industry(name, out_fn='golden.csv'):
    """Create the golden CSV (company ID, industry) from the raw labeled file.

    Keeps only columns 0 (company ID) and 6 (industry) of the rows accepted by
    is_sane(). The output has no header line, which matches read_golden_industry()
    (it does not skip one).

    :param name: path of the raw labeled CSV (its first line is a header).
    :param out_fn: path of the golden CSV to write.
    """
    from io import StringIO
    output = StringIO()
    with open(name, 'r', newline='') as fin:
        reader = csv.reader(fin, quoting=csv.QUOTE_ALL)
        next(reader)  # skip header line
        # On a csv.Error, report it and resume the for-loop on the same reader.
        # `break` is reached only once the reader is exhausted without error.
        # reader.line_num gives the physical line number in the file
        # (header and bad lines included).
        while True:
            try:
                for row in reader:
                    try:
                        if is_sane(row):
                            # Double embedded quotes so the quoted field stays valid CSV.
                            output.write('%s,"%s"\n' % (row[0], row[6].replace('"', '""')))
                        else:
                            print("line %d is malformed. skipping: %s" % (reader.line_num, row))
                    except IndexError:
                        print("line %d caused index error" % reader.line_num)
                break
            except csv.Error:
                print("skipping line %d with error in it" % reader.line_num)
    with open(out_fn, 'w', newline='') as ofile:
        ofile.write(output.getvalue())
if __name__ == "__main__":
    # Fixture generators, run once by hand when preparing test data:
    # _gen_industry2cluster(sys.argv[3])
    # _gen_company2cluster(200000)
    # _gen_golden_company_industry("/home/cnoam/courses_teaching/94290/labeled.csv")
    if len(sys.argv) != 4:
        print("usage: %s industry_cluster.csv company2cluster.csv true_labels_starting_at_150k.csv"
              % sys.argv[0], file=sys.stderr)
        sys.exit(2)
    score = check(sys.argv[1], sys.argv[2], sys.argv[3])
    # Any error in check() propagates as an exception and exits with a non-zero
    # status, so no explicit error handling is needed here.
    print("score=%d" % score)
    sys.exit(0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment