Skip to content

Instantly share code, notes, and snippets.

Created November 28, 2021 10:12
Show Gist options
  • Save noam1023/5d8ffd9b235010e263d49604f48fcb27 to your computer and use it in GitHub Desktop.
Save noam1023/5d8ffd9b235010e263d49604f48fcb27 to your computer and use it in GitHub Desktop.
for review - Almog
# check the results of DSLAB HW2
# expected input:
# industry2cluster_123456789_987654312.csv company2cluster_123456789_987654312.csv golden.csv
import random
import sys
import csv
def read_industries(fn) -> dict:
    """Read the mapping industry name -> cluster ID.

    Each non-blank row is ``"industry name",cluster_id``; the name may contain
    commas when quoted. Surrounding whitespace around the cluster ID is ignored.

    :param fn: CSV file name
    :return: dict {industry name (str) -> cluster ID (int)}
    :raises ValueError: if a row does not have exactly two fields or the
        cluster ID is not an integer.
    """
    mapping = {}
    # newline='' is required by the csv module so quoted fields are parsed correctly.
    with open(fn, 'r', newline='') as infile:
        for row in csv.reader(infile, quoting=csv.QUOTE_ALL):
            if not row:
                continue  # blank line
            industry, cluster_id = row
            mapping[industry] = int(cluster_id.strip())
    return mapping
def read_golden_industry(fn) -> dict:
    """Read the golden CSV mapping company ID -> true industry name.

    Each non-blank row is ``<company ID>,<industry name>``; the industry name
    may itself contain a comma, in which case it is quoted in the file.

    :param fn: CSV file name
    :return: dict {company ID (str, whitespace-stripped) -> industry name (str)}
    :raises ValueError: if a non-blank row does not have exactly two fields.
    """
    mapping = {}
    with open(fn, 'r', newline='') as infile:
        for row in csv.reader(infile, quoting=csv.QUOTE_ALL):
            if not row:
                continue  # blank line
            company_id, industry = row
            # Strip so IDs match regardless of stray spaces around the comma.
            mapping[company_id.strip()] = industry
    return mapping
def check(industry2cluster_fn, company2_cluster_fn, golden_company_industry_fn):
    """Score a submission against the true industry of each company.

    A company counts as correct when the cluster predicted for it in
    ``company2_cluster_fn`` equals the cluster that ``industry2cluster_fn``
    assigns to the company's true industry (taken from
    ``golden_company_industry_fn``). Companies or industries that cannot be
    looked up are reported on stdout and not counted.

    :param industry2cluster_fn: CSV of ``"industry name",cluster_id`` rows.
    :param company2_cluster_fn: CSV of ``company_id,cluster_id`` rows.
    :param golden_company_industry_fn: CSV of ``company_id,industry name`` rows.
    :return: number of companies mapped to the SAME cluster as their true
        industry.
    :raises ValueError: if a prediction row does not have exactly two fields
        or its cluster ID is not an integer.
    """
    golden_company_industry = read_golden_industry(golden_company_industry_fn)
    industry2cluster = read_industries(industry2cluster_fn)
    score = 0
    with open(company2_cluster_fn, 'r', newline='') as company_csv:
        for row in csv.reader(company_csv):
            if not row:
                continue  # blank line
            comp_id, cluster_id = row
            comp_id = comp_id.strip()
            cluster_id = int(cluster_id.strip())
            true_industry = golden_company_industry.get(comp_id)
            if true_industry is None:
                # %s, not %d: a non-numeric ID must not crash the reporting.
                print("company ID %s not found" % comp_id)
                continue
            cluster_from_industry = industry2cluster.get(true_industry)
            if cluster_from_industry is None:
                print("industry %r of company ID %s not found" % (true_industry, comp_id))
                continue
            if cluster_from_industry == cluster_id:
                score += 1
    return score
def check_cluster_sizes(pred_company2cluster: dict, num_clusters: int = 21):
    """Print how many companies were assigned to each cluster.

    The printed list is indexed by cluster ID and has at least ``num_clusters``
    entries (more if a larger cluster ID occurs). When the sizes vary a lot
    (std > half the mean) the mean and std are printed as well, to give a
    rough idea of how the sizes are distributed.

    :param pred_company2cluster: dict {company ID -> cluster ID (int)}
    :param num_clusters: minimum length of the size list (default 21,
        i.e. cluster IDs 0..20).
    :raises ValueError: if any cluster ID is negative.
    """
    import numpy as np
    cluster_ids = list(pred_company2cluster.values())
    if cluster_ids and min(cluster_ids) < 0:
        # A negative index would silently count into the wrong cluster.
        raise ValueError("negative cluster ID %d" % min(cluster_ids))
    size = max(num_clusters, max(cluster_ids, default=-1) + 1)
    cluster_size = [0] * size
    for v in cluster_ids:
        cluster_size[v] += 1
    print("Cluster sizes:", cluster_size)
    mean = np.mean(cluster_size)
    std = np.std(cluster_size)
    # Guard mean == 0 (no predictions) to avoid dividing by zero.
    if mean > 0 and std / mean > 0.5:
        print("cluster stats: mean=%.2f std=%.2f" % (mean, std))
def _gen_industry2cluster(golden_company_industry_fn, out_fn='industry2cluster.csv', num_clusters=20):
    """Generate a fake result file industry (string) -> cluster (int [1..num_clusters]).

    Every distinct industry found in the golden file gets a random cluster ID.
    Rows are written as ``"industry name",cluster_id``, the format
    ``read_industries`` parses.
    """
    golden_company_industry = read_golden_industry(golden_company_industry_fn)
    # Sorted so the output is reproducible under random.seed(); set order of
    # strings depends on the hash seed.
    industries = sorted(set(golden_company_industry.values()))
    print("found %d industries" % len(industries))
    with open(out_fn, 'w', newline='') as fout:
        # QUOTE_NONNUMERIC quotes the name (escaping embedded quotes) but not the int.
        writer = csv.writer(fout, quoting=csv.QUOTE_NONNUMERIC, lineterminator='\n')
        for name in industries:
            writer.writerow([name, random.randint(1, num_clusters)])
def _gen_company2cluster(num_companies):
""" create table of company ID -> cluster ID"""
import random
offset = 1500000
with open('company2cluster.csv', 'w') as fout:
for i in range(num_companies):
fout.write('%d,%d\n' % (i + offset, random.randint(1, 20)))
def is_sane(row):
    """Try to identify malformed lines; the labeled.csv contains a lot of garbage.

    A parsed row is accepted only when it has exactly 8 fields and its first
    field (the company ID) parses as an integer.

    :param row: list of fields as produced by ``csv.reader``
    :return: True if the row looks well-formed, else False
    """
    # Check the length first so indexing below can never raise IndexError.
    if len(row) != 8:
        return False
    try:
        int(row[0])
    except ValueError:
        return False
    return True
def _gen_golden_company_industry(name, out_fn='golden.csv'):
    """Create a CSV file containing only the companyID and industry columns (col 0 and 6).

    The header line of ``name`` is skipped. Rows rejected by ``is_sane`` and
    rows the csv module cannot parse are reported on stdout and skipped. The
    output rows are ``company ID,industry name`` (the name is quoted when it
    contains a comma), which is the format ``read_golden_industry`` parses.
    """
    # Keep the whole result in memory so a failure while reading leaves no
    # partial output file behind.
    rows = []
    with open(name, 'r', newline='') as fin:
        reader = csv.reader(fin, quoting=csv.QUOTE_ALL)
        next(reader, None)  # skip header line
        while True:
            try:
                for row in reader:
                    if is_sane(row):
                        rows.append((row[0], row[6]))
                    else:
                        print("line %d is malformed. skipping" % reader.line_num)
                break  # reader exhausted
            # The reader can resume after either error, so loop back and continue.
            except IndexError:
                print("line %d caused index error" % reader.line_num)
            except csv.Error:
                print("skipping line %d with error in it" % reader.line_num)
    with open(out_fn, 'w', newline='') as ofile:
        csv.writer(ofile, lineterminator='\n').writerows(rows)
if __name__ == "__main__":
    # Test data was created with the _gen_* helpers above, e.g.
    # _gen_company2cluster(200000) and _gen_industry2cluster(<golden csv>).
    if len(sys.argv) != 4:
        # sys.exit(str) prints to stderr and exits with status 1.
        sys.exit("usage: check_hw2 industry_cluster.csv company2cluster.csv golden_company_industry.csv")
    score = check(sys.argv[1], sys.argv[2], sys.argv[3])
    # Any error above raises, causing a non-zero exit status, so nothing else is needed.
    print("score=%d" % score)
Copy link

The `_gen_*` functions were only used to create test data; you can ignore them.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment