Skip to content

Instantly share code, notes, and snippets.

@cuter44
Created June 7, 2021 13:20
Show Gist options
  • Save cuter44/8abc91dbc4d6ae1f01a23e2f43d736cd to your computer and use it in GitHub Desktop.
Save cuter44/8abc91dbc4d6ae1f01a23e2f43d736cd to your computer and use it in GitHub Desktop.
Embedding conversion for NNQLM
#! python3
"""
USAGE
%0 n_voc n_dim INPUT_FILE OUTPUT_FILE
Convert an embedding table from text format to the Google C (word2vec) binary
format, which the downstream program needs as input.
n_voc, n_dim : Dimensions of the embedding table; ignored if these two values
are present in INPUT_FILE.
INPUT_FILE : Embedding table in text format. It may or may not begin with a row
of two ints giving n_voc and n_dim. This is followed by n_voc
rows, each consisting of a word and n_dim floats,
separated by whitespace characters.
OUTPUT_FILE : Default to stdout
"""
import sys
import numpy
def _parse_header(line):
    """Return ``(n_voc, n_dim)`` if *line* is exactly two ints, else ``None``."""
    fields = line.split()
    if len(fields) != 2:
        return None
    try:
        return int(fields[0]), int(fields[1])
    except ValueError:
        return None


def convert_embedding(fin, fout, n_voc, n_dim):
    """Convert a text embedding table to the binary word/vector format.

    The output starts with an ASCII header ``"<n_voc> <n_dim>\\n"``. Each
    data row is then written as the UTF-8 word, one space, the vector as raw
    float32 bytes (native byte order) and a trailing newline.

    Args:
        fin: Text stream positioned at the start of the input table.
        fout: Binary stream that receives the converted data.
        n_voc: Fallback vocabulary size, used when the input has no header.
        n_dim: Fallback vector dimension, used when the input has no header.

    Returns:
        The number of data rows written.

    Raises:
        ValueError: If a vector component cannot be parsed as a float.

    Mismatches between the declared and the actual sizes are reported on
    stderr (never on stdout, which may be carrying the binary output) and do
    not abort the conversion.
    """
    first = fin.readline()
    header = _parse_header(first)
    if header is not None:
        # Values found in the file take precedence over the fallbacks.
        n_voc, n_dim = header
        pending = []
    else:
        # No header: the first line is ordinary data and must not be lost.
        pending = [first]
    fout.write("{} {}\n".format(n_voc, n_dim).encode())

    def rows():
        for row in pending:
            yield row
        for row in fin:
            yield row

    cnt_line = 0
    for line in rows():
        crumbs = line.split()
        if not crumbs:
            # Blank lines (e.g. a trailing newline) carry no data; skip them
            # instead of failing on crumbs[0].
            continue
        cnt_line += 1
        word = crumbs[0]
        vec = numpy.array([float(e) for e in crumbs[1:]], dtype=numpy.float32)
        fout.write(word.encode() + b" " + vec.tobytes() + b"\n")
        if vec.shape[0] != n_dim:
            print("(!) inconsistent vector dimension", cnt_line, word,
                  file=sys.stderr)

    if cnt_line != n_voc:
        print("(!) inconsistent vocabulary size", cnt_line, file=sys.stderr)
    return cnt_line


def main(args=None):
    """Command-line entry point: ``n_voc n_dim INPUT_FILE [OUTPUT_FILE]``.

    Args:
        args: Argument list without the program name. Defaults to
            ``sys.argv[1:]``.

    When OUTPUT_FILE is omitted, the binary result is written to stdout.
    """
    if args is None:
        args = sys.argv[1:]
    if len(args) < 3:
        sys.exit("usage: n_voc n_dim INPUT_FILE [OUTPUT_FILE]")

    n_voc = int(args[0])
    n_dim = int(args[1])
    fn_in = args[2]
    fn_out = args[3] if len(args) > 3 else None

    with open(fn_in, "r", encoding="utf-8") as fin:
        if fn_out:
            with open(fn_out, "wb") as fout:
                convert_embedding(fin, fout, n_voc, n_dim)
        else:
            # sys.stdout is a text stream and rejects bytes; its underlying
            # binary buffer is the right target for binary output.
            fout = sys.stdout.buffer
            convert_embedding(fin, fout, n_voc, n_dim)
            fout.flush()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment