for review - Almog
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Check the results of DSLAB HW2.
# Expected invocation:
#   check_hw2.py industry2cluster_123456789_987654312.csv company2cluster_123456789_987654312.csv golden.csv
import random | |
import sys | |
import csv | |
def read_industries(fn) -> dict:
    """
    Read the submitted industry -> cluster mapping.

    :param fn: CSV file whose rows are <industry name>,<cluster id>; the industry
               name may be quoted (and then may contain commas)
    :return: dict{industry name -> int cluster id}
    """
    r = {}
    # newline='' is what the csv module requires so that quoted fields and
    # \r\n line endings are parsed correctly
    with open(fn, 'r', newline='') as fin:
        for row in csv.reader(fin):
            if not row:
                continue  # tolerate blank lines, e.g. a trailing empty line
            industry, cluster_id = row
            r[industry] = int(cluster_id.strip())
    return r
def read_golden_industry(fn) -> dict:
    """
    Read the golden CSV of <company ID>,<industry name>.

    :param fn: CSV file name; industry names may contain commas, in which case
               they are quoted
    :return: dict{company ID (str, exactly as written in the file) -> industry name}
    :raises ValueError: if a non-blank row does not have exactly two fields
    """
    r = {}
    with open(fn, 'r', newline='') as fin:
        reader = csv.reader(fin)
        for row in reader:
            if not row:
                continue  # tolerate blank lines
            if len(row) != 2:
                raise ValueError("%s line %d: expected 2 fields, got %d"
                                 % (fn, reader.line_num, len(row)))
            company_id, industry = row
            r[company_id] = industry
    return r
def check(industry2cluster_fn, company2_cluster_fn, golden_company_industry_fn):
    """
    Score a submission by comparing its company clustering to the golden industries.

    For every company in the submitted company -> cluster file, look up the company's
    true industry in the golden file, map that industry to a cluster through the
    submitted industry -> cluster file, and count the company as correct when the two
    cluster IDs agree. Companies whose ID or industry cannot be resolved are reported
    and skipped.

    :param industry2cluster_fn: submitted CSV of <industry name>,<cluster id>
    :param company2_cluster_fn: submitted CSV of <company id>,<cluster id>
    :param golden_company_industry_fn: ground-truth CSV of <company id>,<industry name>
    :return: number of companies mapped to the SAME cluster as their true industry
    """
    correct = 0
    golden_company_industry = read_golden_industry(golden_company_industry_fn)
    industry2cluster = read_industries(industry2cluster_fn)
    pred_company2cluster = {}
    with open(company2_cluster_fn, 'r', newline='') as company_csv:
        for row in csv.reader(company_csv):
            if not row:
                continue  # tolerate blank lines, e.g. a trailing empty line
            comp_id, cluster_id = row
            comp_id = comp_id.strip()
            cluster_id = int(cluster_id.strip())
            pred_company2cluster[comp_id] = cluster_id

            # Report the two lookup failures separately: a missing company and a
            # missing industry point at different problems in the submission.
            true_industry = golden_company_industry.get(comp_id)
            if true_industry is None:
                print("company ID %s not found" % comp_id)
                continue
            cluster_from_industry = industry2cluster.get(true_industry)
            if cluster_from_industry is None:
                print("industry %r of company ID %s not found in %s"
                      % (true_industry, comp_id, industry2cluster_fn))
                continue
            correct += cluster_from_industry == cluster_id
    check_cluster_sizes(pred_company2cluster)
    return correct
def check_cluster_sizes(pred_company2cluster: dict, num_clusters: int = 20):
    """
    Print how many companies fall into each cluster and warn when sizes are very uneven.

    :param pred_company2cluster: dict{company ID -> int cluster id}
    :param num_clusters: highest valid cluster id (ids are expected in [1..num_clusters];
                         0 is also accepted)
    :raises ValueError: if a cluster id is outside [0, num_clusters]
    """
    import statistics

    cluster_size = [0] * (num_clusters + 1)
    for v in pred_company2cluster.values():
        # Reject out-of-range ids explicitly: a negative id would otherwise be
        # silently counted in the wrong slot via negative indexing.
        if not 0 <= v <= num_clusters:
            raise ValueError("cluster id %d is outside the range [0, %d]" % (v, num_clusters))
        cluster_size[v] += 1
    print("Cluster sizes:", cluster_size)

    # To get a rough idea of how the cluster sizes are distributed. Ids are expected
    # to be 1-based, so slot 0 is only included when some company actually uses it;
    # otherwise the always-empty slot would drag the mean down and inflate the std.
    sizes = cluster_size if cluster_size[0] else cluster_size[1:]
    mean = statistics.mean(sizes)
    std = statistics.pstdev(sizes)  # population std, same as numpy.std's default
    if mean and std / mean > 0.5:  # mean == 0 means there are no companies at all
        print("cluster stats: mean=%.1f std=%.1f" % (mean, std))
def _gen_industry2cluster(golden_company_industry_fn, out_fn='industry2cluster.csv', num_clusters=20):
    """
    Generate a fake result industry (string) -> cluster (int in [1..num_clusters]).

    Test-data helper: every distinct industry in the golden file is assigned a
    random cluster, and the mapping is written as CSV to out_fn.
    """
    golden_company_industry = read_golden_industry(golden_company_industry_fn)
    industries = set(golden_company_industry.values())
    print("found %d industries" % len(industries))
    with open(out_fn, 'w', newline='') as fout:
        # QUOTE_NONNUMERIC quotes the industry name (escaping any embedded quotes,
        # and protecting embedded commas) but leaves the int cluster id unquoted.
        writer = csv.writer(fout, quoting=csv.QUOTE_NONNUMERIC, lineterminator='\n')
        for name in sorted(industries):  # sorted: stable row order between runs
            writer.writerow([name, random.randint(1, num_clusters)])
def _gen_company2cluster(num_companies): | |
""" create table of company ID -> cluster ID""" | |
import random | |
offset = 1500000 | |
with open('company2cluster.csv', 'w') as fout: | |
for i in range(num_companies): | |
fout.write('%d,%d\n' % (i + offset, random.randint(1, 20))) | |
def is_sane(row):
    """
    Try to identify malformed lines; labeled.csv contains a lot of garbage.

    :param row: list of fields as produced by csv.reader
    :return: True only if the row has exactly 8 fields and its first field
             (the company ID) parses as an int
    """
    # Check the length first so an empty row returns False instead of raising IndexError.
    if len(row) != 8:
        return False
    try:
        int(row[0])
    except ValueError:
        return False
    return True
def _gen_golden_company_industry(name, out_fn='golden.csv'):
    """
    Create the golden CSV of <company ID>,<industry name> from the raw labeled file.

    Test-data helper. Takes column 0 (company ID) and column 6 (industry name) of
    every well-formed record in `name` (the header line is skipped) and writes them
    with csv.writer. Names containing commas or quotes are therefore quoted and are
    read back intact by read_golden_industry().

    :param name: path of the raw labeled CSV file
    :param out_fn: path of the golden CSV to write
    """
    rows = []
    with open(name, 'r', newline='') as fin:
        reader = csv.reader(fin)
        next(reader, None)  # skip header line (no-op on an empty file)
        while True:
            try:
                for row in reader:
                    try:
                        sane = is_sane(row)
                    except IndexError:
                        print("line %d caused index error" % reader.line_num)
                        continue
                    if sane:
                        # Keep the industry *name*: read_golden_industry() and check()
                        # look industries up by name in industry2cluster.
                        rows.append((row[0], row[6]))
                    else:
                        print("line %d is malformed. skipping" % reader.line_num)
                break  # reader exhausted normally
            except csv.Error:
                # The reader has already consumed the offending record; restarting the
                # for-loop on the same reader continues after it.
                print("skipping line %d with error in it" % reader.line_num)
    with open(out_fn, 'w', newline='') as fout:
        csv.writer(fout, lineterminator='\n').writerows(rows)
if __name__ == "__main__":
    # The _gen* helpers above only create test data; they are not run here.
    if len(sys.argv) != 4:
        print("usage: check_hw2.py industry2cluster.csv company2cluster.csv golden_company_industry.csv",
              file=sys.stderr)
        sys.exit(2)
    score = check(sys.argv[1], sys.argv[2], sys.argv[3])
    # Any error above raises, so the process exits with a non-zero status; no extra handling needed.
    print("score=%d" % score)
    sys.exit(0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The _gen* functions were only used to create test data; you can ignore them.