@jaklinger
Last active August 11, 2020 13:31
Read arXiv vectors
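A resumable, chunked read of pre-computed arXiv article vectors from the nesta MySQL database into flat numpy arrays. Rows are fetched in chunks keyed by article_id, so a dropped connection (which surfaces as a json.JSONDecodeError from the JSON vector column) can simply be retried from the last row collected. The finished arrays are written to disk with np.save.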
from nesta.core.orms.orm_utils import db_session, get_mysql_engine
from nesta.core.orms.arxiv_orm import ArticleVector
import numpy as np
import json
import os

os.environ['MYSQLDB'] = "/path/to/innovation-mapping-5712.config"


def query_and_bundle(session, fields, start, limit, filter_):
    """Fetch one chunk of (article_id, vector) rows: by id filter if one
    is available, otherwise by offset."""
    q = session.query(*fields)
    if filter_ is not None:
        q = q.filter(filter_)
    else:
        q = q.offset(start)
    ids, vectors = zip(*q.limit(limit))
    return np.array(ids, dtype=np.dtype('U40')), np.array(vectors, dtype=np.float32)


def prefill_inputs():
    """Allocate output arrays sized to the full vectors table."""
    engine = get_mysql_engine("MYSQLDB", "mysqldb", "production")
    with db_session(engine) as session:
        count = session.query(ArticleVector).count()
        a_vector, = session.query(ArticleVector.vector).limit(1).one()
        dim = len(a_vector)
    data = np.empty((count, dim), dtype=np.float32)
    # np.zeros, not np.empty: unfilled slots must compare equal to ''
    # for the resume logic in read_data to work
    ids = np.zeros((count,), dtype=np.dtype('U40'))
    return data, ids


def read_data(data, ids, chunksize=10000, start=None, max_chunks=None):
    """Read the vectors table in chunks, filling `data` and `ids` in place."""
    engine = get_mysql_engine("MYSQLDB", "mysqldb", "production")
    fields = (ArticleVector.article_id, ArticleVector.vector)
    count, _ = data.shape
    start = (ids != '').sum() if start is None else start  # resume, or take the given value
    filter_ = None
    n_chunks = 0
    while start < count:
        if max_chunks is not None and n_chunks >= max_chunks:
            break
        if start % 100000 == 0:
            print("Collecting row", start)
        limit = chunksize if start + chunksize < count else None
        with db_session(engine) as session:
            _ids, _data = query_and_bundle(session, fields, start, limit, filter_)
        filter_ = ArticleVector.article_id > _ids[-1]  # next chunk resumes after the last id seen
        ids[start:start + _ids.shape[0]] = _ids
        data[start:start + _data.shape[0]] = _data
        start += chunksize
        n_chunks += 1


if __name__ == "__main__":
    data, ids = prefill_inputs()  # preallocated numpy arrays
    while "reading data":
        try:
            n = (ids != '').sum()  # number of rows collected before the connection broke
            if n > 0:
                print("Restarting from", n)
            read_data(data, ids)  # start or continue reading
        except json.JSONDecodeError:  # the connection dropped mid-read, corrupting a JSON vector
            continue  # retry from where we left off
        else:
            break  # done
    np.save('arxiv_vectors.npy', data)
    np.save('arxiv_vectors_ids.npy', ids)
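
To load the saved arrays back for downstream work, something like the following should do; a minimal sketch, in which the optional mmap_mode='r' keeps the (potentially large) vector matrix on disk rather than in memory:

import numpy as np

# Load the arrays written by the script above; mmap_mode='r' memory-maps
# the matrix instead of reading it all into RAM (optional).
data = np.load('arxiv_vectors.npy', mmap_mode='r')
ids = np.load('arxiv_vectors_ids.npy')

assert data.shape[0] == ids.shape[0]  # one article_id per vector row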