-
-
Save Q726kbXuN/97a88e3d4e6101811fb8bd554d3a34df to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
from datetime import datetime | |
import itertools, json, os, re, sqlite3, subprocess | |
# SQLite file that stores the per-tag and per-tag-pair occurrence counts.
DB_NAME = "tag_connections.db"
def increment_value(db, key, dest, select_sql):
    """Bump the in-memory counter for *key*, seeding it from the database.

    *dest* maps key tuples to ``[is_new, count]`` pairs: ``is_new`` is True
    when the key has no row in the database yet (so it will need an INSERT
    rather than an UPDATE when the cache is flushed by ``dump_values``).
    """
    entry = dest.get(key)
    if entry is not None:
        # Already cached this flush cycle — just count it.
        entry[1] += 1
        return
    existing = db.execute(select_sql, key).fetchone()
    if existing is None:
        dest[key] = [True, 1]
    else:
        # Resume counting from the value already persisted in the DB.
        dest[key] = [False, existing[0] + 1]
def dump_values(db, dest, insert_sql, update_sql):
    """Flush the cached counters in *dest* to the database.

    Entries flagged new (``[True, count]``) become INSERTs of key + count;
    entries flagged existing (``[False, count]``) become UPDATEs with the
    count first, matching the ``count=? WHERE key=?`` parameter order.
    """
    inserts = [key + (val[1],) for key, val in dest.items() if val[0]]
    if inserts:
        db.executemany(insert_sql, inserts)
    updates = [(val[1],) + key for key, val in dest.items() if not val[0]]
    # BUG FIX: this previously re-tested len(inserts), so a flush that
    # happened to contain no new rows silently dropped all its UPDATEs.
    if updates:
        db.executemany(update_sql, updates)
def get_rows(cmd):
    """Run *cmd* (a space-separated command string) and yield its stdout lines.

    Lines keep their trailing newline; the subprocess's stdout is streamed,
    so arbitrarily large output is never held in memory at once.
    """
    args = cmd.split(" ")
    proc = subprocess.Popen(
        args,
        stdout=subprocess.PIPE,
        universal_newlines=True,
        encoding="utf-8",
    )
    yield from proc.stdout
def create_db():
    """Build ``DB_NAME`` by streaming Posts.xml out of the Stack Overflow dump.

    For every question post (PostTypeId=1) with at least two tags, counts
    each tag's occurrences (table ``tag``) and every ordered tag pair's
    co-occurrences (table ``pair``), flushing in-memory caches to SQLite at
    exponentially growing intervals while printing progress.
    """
    # Scan the archive listing for the uncompressed size of Posts.xml so
    # progress can be reported as a percentage of bytes streamed.
    # NOTE(review): if Posts.xml is missing from the listing, expected_size
    # stays unbound and the progress print below raises — confirm acceptable.
    for row in get_rows("7z l Stackoverflow.com-Posts.7z"):
        row = row.strip().split(' ')
        if row[-1] == "Posts.xml":
            expected_size = int(row[4])
    # The Tags attribute in Posts.xml is XML-escaped,
    # e.g. Tags="&lt;python&gt;&lt;sql&gt;".
    r = re.compile('Tags="([^"]+?)"')
    # Always rebuild from scratch.
    if os.path.isfile(DB_NAME):
        os.unlink(DB_NAME)
    db = sqlite3.connect(DB_NAME)
    db.execute("CREATE TABLE tag(x TEXT NOT NULL, count INT NOT NULL);")
    db.execute("CREATE UNIQUE INDEX tag_index_x ON tag(x);")
    db.execute("CREATE TABLE pair(x TEXT NOT NULL, y TEXT NOT NULL, count INT NOT NULL);")
    db.execute("CREATE INDEX pair_index_x ON pair(x);")
    db.execute("CREATE UNIQUE INDEX pair_index_xy ON pair(x, y);")
    db.execute("CREATE TABLE info(count);")
    db.execute("INSERT INTO info (count) VALUES (0);")
    db.commit()
    # Bulk-load pragmas: trade durability for speed — the DB is simply
    # rebuilt from scratch if a run is interrupted.
    db.execute('PRAGMA synchronous = OFF;')
    db.execute('PRAGMA journal_mode = OFF;')
    db.execute('PRAGMA locking_mode = EXCLUSIVE;')
    db.execute('PRAGMA secure_delete = OFF;')
    total_parsed = 0
    total_pairs, total_tags = {}, {}
    cur_pos = 0
    dump_at = 1  # flush interval in parsed posts; doubles up to 262144
    last_dump = datetime.utcnow()
    for row in get_rows('7z -so x Stackoverflow.com-Posts.7z Posts.xml'):
        cur_pos += len(row)
        # Only question posts carry a Tags attribute worth counting.
        if row.startswith(' <row') and 'PostTypeId="1"' in row and 'Tags="' in row:
            if m := r.search(row):
                tags = m.group(1)
                # BUG FIX: the attribute value is XML-entity-escaped, so the
                # tag delimiters are the 4-character entities "&lt;"/"&gt;"
                # (hence the [4:-4] slice), not literal "<"/">" — the
                # literals had been HTML-decoded and matched nothing.
                if tags.startswith("&lt;") and tags.endswith("&gt;"):
                    tags = tags[4:-4].split("&gt;&lt;")
                    # Single-tag posts have no connections to record.
                    if len(tags) >= 2:
                        for tag in tags:
                            increment_value(db, (tag,), total_tags, "SELECT count FROM tag WHERE x = ?;")
                        # Store both (x, y) and (y, x) so lookups work
                        # from either side of the pair.
                        for x, y in itertools.combinations(tags, 2):
                            increment_value(db, (x, y), total_pairs, "SELECT count FROM pair WHERE x = ? AND y = ?;")
                            increment_value(db, (y, x), total_pairs, "SELECT count FROM pair WHERE x = ? AND y = ?;")
            total_parsed += 1
            if total_parsed % dump_at == 0:
                dump_at = min(262144, dump_at * 2)
                now = datetime.utcnow()
                secs = (now - last_dump).total_seconds()
                print(f"Working, parsed {total_parsed:,}, at {cur_pos / expected_size * 100:.2f}%, took {secs:.1f} seconds...")
                last_dump = now
                dump_values(db, total_pairs, "INSERT INTO pair(x, y, count) VALUES(?,?,?);", "UPDATE pair SET count=? WHERE x=? AND y=?;")
                dump_values(db, total_tags, "INSERT INTO tag(x, count) VALUES(?,?);", "UPDATE tag SET count=? WHERE x=?;")
                db.execute("UPDATE info SET count=?;", (total_parsed,))
                db.commit()
                # Drop the caches; future misses re-seed from the DB,
                # so counting continues from the persisted values.
                total_pairs, total_tags = {}, {}
    # Final flush of whatever accumulated since the last periodic dump.
    dump_values(db, total_pairs, "INSERT INTO pair(x, y, count) VALUES(?,?,?);", "UPDATE pair SET count=? WHERE x=? AND y=?;")
    dump_values(db, total_tags, "INSERT INTO tag(x, count) VALUES(?,?);", "UPDATE tag SET count=? WHERE x=?;")
    db.execute("UPDATE info SET count=?;", (total_parsed,))
    db.commit()
    db.close()
def dump_stats():
    """Write the most common tags and their strongest co-tags as JSON lines.

    Emits ``tag_connections.jsonl`` (and echoes each line to stdout): one
    row per top tag, listing connected tags whose co-occurrence share is
    at least 1% of the base tag's count.
    """
    base_tag_limit = 1000  # how many of the most common tags to report
    pair_tag_limit = 25    # max connected tags considered per base tag
    db = sqlite3.connect(DB_NAME)
    with open("tag_connections.jsonl", "wt") as f:

        def dump_row(value):
            # Each row is one JSON document, mirrored to stdout.
            line = json.dumps(value)
            print(line)
            f.write(line + "\n")

        total = db.execute("SELECT count FROM info;").fetchone()[0]
        dump_row([f"# From a total of {total:,} answers"])
        dump_row([f"# [tag, answers, [connected_tag, percent_connected], ...]"])
        dump_row([""])
        top_tags = db.execute(
            f"SELECT x, count FROM tag ORDER BY count DESC LIMIT {base_tag_limit};"
        )
        for base_tag, base_count in top_tags:
            entry = [base_tag, base_count]
            connected = db.execute(
                f"SELECT y, count FROM pair WHERE x = ? ORDER BY count DESC LIMIT {pair_tag_limit}",
                (base_tag,),
            )
            for other_tag, pair_count in connected:
                share = pair_count / base_count * 100
                if share < 1:
                    # Results are count-sorted, so everything after is smaller.
                    break
                # Truncate (not round) to two decimal places.
                entry.append([other_tag, int(share * 100) / 100])
            dump_row(entry)
def main():
    """Entry point: build the tag database if absent, then dump the stats."""
    # Reuse an existing database — rebuilding it from the dump is expensive.
    if not os.path.isfile(DB_NAME):
        create_db()
    dump_stats()
# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment