Created
February 13, 2018 10:16
-
-
Save martinpopel/6d1aed29a70659d33562ecbfa62e05fd to your computer and use it in GitHub Desktop.
A script to compute the number of subwords in a given raw bitext; useful for estimating the number of training epochs in Tensor2Tensor (T2T).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
import itertools
import sys

import tensorflow as tf
from tensor2tensor.data_generators import text_encoder
flags = tf.flags
FLAGS = flags.FLAGS

# Inputs: a line-aligned bitext and the subword vocabulary used to encode it.
# These default to None, but the script cannot do anything useful without them.
flags.DEFINE_string("vocab", None, "Path to the subword vocabulary (required)")
flags.DEFINE_string("src", None, "Path to the source-language text (required)")
flags.DEFINE_string("trg", None, "Path to the target-language text (required)")
# TODO print the actual subwords, use vocab._subtoken_id_to_subtoken_string()
# instead of _subtoken_ids_to_tokens()
flags.DEFINE_bool(
    "print", False,
    "Print to stdout one line per sentence pair containing one 'a' per subword "
    "of the longer side")
def eprint(*values, **options):
    """Like print(), but always writes to sys.stderr.

    Keyword options are forwarded to print(); passing ``file`` is an error
    because the destination is fixed.
    """
    stream = sys.stderr
    print(*values, file=stream, **options)
def words_subwords(vocab, string):
    """Count whitespace-separated words and subword ids in *string*.

    *vocab* is any object with an ``encode(str)`` method returning a sequence
    of ids; main() passes a tensor2tensor ``SubwordTextEncoder``.

    Returns:
        A ``(n_words, n_subwords)`` tuple.
    """
    return len(string.split()), len(vocab.encode(string))
# Running totals updated by main() and reported by print_stats().
# Prefixes: s_ = source side, t_ = target side, m_ = per-pair maximum of the
# two sides. sents counts sentence pairs processed so far.
s_words = t_words = m_words = 0
s_subws = t_subws = m_subws = 0
sents = 0
def print_stats():
    """Print the running totals held in the module-level counters to stderr.

    The "total" line uses the m_* counters, which accumulate max(source,
    target) per sentence pair. Its subwords/words ratio is reported as 0 when
    no words have been counted yet (e.g. empty input) instead of raising
    ZeroDivisionError.
    """
    ratio = m_subws / m_words if m_words else 0.0
    eprint("\ntotal: sents=%d words=%d subwords=%d subwords/words %.4f"
           % (sents, m_words, m_subws, ratio))
    eprint("source: words=%d subwords=%d" % (s_words, s_subws))
    eprint("target: words=%d subwords=%d" % (t_words, t_subws))
def main(_):
    """Count words and subwords over a line-aligned bitext and report totals.

    Reads --src and --trg in lockstep, strips each line, encodes it with the
    --vocab SubwordTextEncoder and adds the counts to the module-level
    counters. Totals are printed (via print_stats) every 100000 sentence pairs
    and once at the end. With --print, one line of "a" characters (one per
    subword of the longer side) is written to stdout for each pair.

    Exits with an error message if a required flag is missing or if the two
    files do not contain the same number of lines.
    """
    global s_words, t_words, m_words, s_subws, t_subws, m_subws, sents
    # The path flags default to None; fail up front with a clear message
    # rather than later (e.g. with a TypeError from open(None)).
    missing = [name for name in ("vocab", "src", "trg") if not getattr(FLAGS, name)]
    if missing:
        sys.exit("error: missing required flag(s): %s"
                 % ", ".join("--" + name for name in missing))
    vocab = text_encoder.SubwordTextEncoder(FLAGS.vocab)
    with open(FLAGS.src, encoding="utf-8") as src, \
            open(FLAGS.trg, encoding="utf-8") as trg:
        # zip() would silently stop at the end of the shorter file, hiding a
        # misaligned bitext; zip_longest pads with None so it can be detected.
        for s, t in itertools.zip_longest(src, trg):
            if s is None or t is None:
                shorter = FLAGS.src if s is None else FLAGS.trg
                sys.exit("error: %s ends after %d lines; --src and --trg must "
                         "have the same number of lines" % (shorter, sents))
            sents += 1
            s = s.strip()
            t = t.strip()
            s_w, s_s = words_subwords(vocab, s)
            t_w, t_s = words_subwords(vocab, t)
            s_words += s_w
            t_words += t_w
            m_words += max(s_w, t_w)
            s_subws += s_s
            t_subws += t_s
            m_subws += max(s_s, t_s)
            if sents % 100000 == 0:
                print_stats()
            if FLAGS.print:
                print("a" * max(s_s, t_s))
    print_stats()
if __name__ == "__main__":
    # Hand control to TensorFlow's app runner, naming the entry point
    # explicitly instead of relying on its lookup of __main__.main.
    tf.app.run(main=main)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.