Skip to content

Instantly share code, notes, and snippets.

@soaxelbrooke
Last active March 1, 2023 09:37
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save soaxelbrooke/f75ff2a6f8a432afaf72990bcfe6f086 to your computer and use it in GitHub Desktop.
Save soaxelbrooke/f75ff2a6f8a432afaf72990bcfe6f086 to your computer and use it in GitHub Desktop.
Script for converting txt word embedding files to SQLite databases for fast embedding lookup.
#!/usr/bin/env python3.6
"""
Example usage:
$ python3.6 wvsqlite.py glove.840B.300d.txt
Produces an SQLite database at `vectors.sqlite` (in the current directory) storing each word
vector as a byte string of floats, indexed by token for fast lookup when your vocab is much
smaller than the embedding vocab (as most real vocabs are).
The float size is set via the FLOAT_BYTES env var (4 or 8; default 4), and LIMIT can be set
to keep only the first N word vectors in the file.
Metadata is also saved in the `vector_meta` table.
"""
import pandas
import csv
import sqlite3
from tqdm import tqdm
import sys
import os
def guess_embed_dim(embeddings_path: str) -> int:
    """Return the number of whitespace-separated fields per vector line.

    Despite the name, the value counts the token column too, i.e. it is the
    vector dimension + 1 (301 for a 300-d file). ``load_wvs`` relies on this:
    it uses the value as the number of column names and then turns column 0
    into the index.

    The first two lines are inspected and the larger field count wins. This
    steps over a two-field "<vocab_size> <dim>" header of the kind fastText
    ``.vec`` files start with.

    :param embeddings_path: path to a text embedding file.
    :return: field count per line (0 for an empty file).
    """
    with open(embeddings_path, encoding='utf-8') as infile:
        # readline() returns '' at EOF, so short files don't raise StopIteration.
        first_lines = [infile.readline() for _ in range(2)]
    # str.split() with no argument ignores trailing spaces and the newline.
    # split(' ') would count a trailing " \n" as an extra field.
    return max(len(line.split()) for line in first_lines)
def load_wvs(embeddings_path: str, embedding_dim: int, limit=None):
    """Load a whitespace-delimited text embedding file into a DataFrame.

    :param embeddings_path: path to a GloVe-style file (no header) or a
        fastText-style file whose first line is "<vocab_size> <dim>".
    :param embedding_dim: number of whitespace-separated fields per line,
        *including* the token (the value ``guess_embed_dim`` returns). Column 0
        becomes the index, so the frame has ``embedding_dim - 1`` columns.
    :param limit: optional maximum number of vector lines to read. It may be
        an int or a numeric string (e.g. straight from an env var).
    :return: ``pandas.DataFrame`` indexed by token, one column per component.
    """
    if limit is not None:
        limit = int(limit)
    with open(embeddings_path, encoding='utf-8') as infile:
        # Skip a two-integer "<vocab_size> <dim>" header (fastText). Otherwise
        # rewind so the first vector line (GloVe) is not lost.
        first_fields = infile.readline().split()
        has_header = (len(first_fields) == 2
                      and all(field.isdigit() for field in first_fields))
        if not has_header:
            infile.seek(0)
        return pandas.read_csv(
            infile,
            header=None,
            sep=r'\s+',              # equivalent to the deprecated delim_whitespace=True
            names=list(range(embedding_dim)),
            quoting=csv.QUOTE_NONE,  # treat quote characters in tokens literally
            na_filter=False,         # keep tokens like "null", "NaN", "NA" as text
            dtype={0: str},          # tokens such as "1" or "2.5" stay strings
            nrows=limit,
            index_col=0,
        )
def insert_wvs(wvs: pandas.DataFrame, embedding_dim: int, float_bytes: int,
               db_path: str = 'vectors.sqlite'):
    """Write word vectors into a fresh SQLite database.

    Any existing file at ``db_path`` is deleted first. Two tables are created:

    * ``vectors(token text primary key, vector_bytes blob)``: each row of
      ``wvs`` is encoded as raw float32/float64 bytes. Only the first
      occurrence of a duplicated token is kept.
    * ``vector_meta(vector_float_bytes, embedding_dimensions, vocab_size)``: a
      single row describing the stored blobs.

    :param wvs: DataFrame indexed by token, one numeric column per component.
    :param embedding_dim: kept for backward compatibility and not used. The
        recorded dimension is ``wvs.shape[1]``, so it always matches the blob
        length. The caller-supplied value may also count the token column.
    :param float_bytes: 4 (float32) or 8 (float64).
    :param db_path: output database path; defaults to ``vectors.sqlite`` in
        the current directory.
    :raises ValueError: if ``float_bytes`` is not 4 or 8.
    """
    if float_bytes not in (4, 8):
        raise ValueError(f'float_bytes must be 4 or 8, got {float_bytes!r}')
    float_type = 'float32' if float_bytes == 4 else 'float64'
    try:
        os.remove(db_path)
    except FileNotFoundError:
        pass  # nothing to replace
    conn = sqlite3.connect(db_path)
    try:
        conn.execute('''
        CREATE TABLE vector_meta (
            vector_float_bytes integer,
            embedding_dimensions integer,
            vocab_size integer
        )
        ''')
        conn.execute('CREATE TABLE vectors (token text primary key, vector_bytes blob);')
        rows = ((token, series.values.astype(float_type).tobytes())
                for token, series in tqdm(wvs.iterrows(), total=wvs.shape[0]))
        # OR IGNORE keeps the first row for a duplicated token, as before.
        conn.executemany('INSERT OR IGNORE INTO vectors VALUES (?, ?)', rows)
        # Count what was actually stored: duplicates were dropped above.
        vocab_size = conn.execute('SELECT COUNT(*) FROM vectors').fetchone()[0]
        conn.execute('INSERT INTO vector_meta VALUES (?, ?, ?)',
                     (float_bytes, wvs.shape[1], vocab_size))
        conn.commit()
    finally:
        conn.close()
if __name__ == '__main__':
    # CLI: wvsqlite.py EMBEDDINGS_PATH, tuned via the FLOAT_BYTES / LIMIT env vars.
    if len(sys.argv) < 2:
        # sys.exit with a string prints it to stderr and exits with status 1.
        sys.exit('Usage: wvsqlite.py EMBEDDINGS_PATH  (env: FLOAT_BYTES=4|8, LIMIT=N)')
    embed_path = sys.argv[1]
    embed_dim = guess_embed_dim(embed_path)
    float_bytes = int(os.getenv('FLOAT_BYTES', 4))
    print('Loading word vectors...')
    wvs = load_wvs(embed_path, embed_dim, os.getenv('LIMIT'))
    print('Saving word vector bytes to vectors.sqlite...')
    insert_wvs(wvs, embed_dim, float_bytes)
    print('Done!')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment