Skip to content

Instantly share code, notes, and snippets.

@mikemoritz
Last active December 17, 2020 20:21
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mikemoritz/a5bf76193ccb16d018a1af9ec584fb41 to your computer and use it in GitHub Desktop.
Save mikemoritz/a5bf76193ccb16d018a1af9ec584fb41 to your computer and use it in GitHub Desktop.
'''
This script pulls apart the logic in argostranslate.translate.apply_packaged_translation and prints details.
The model is installed into a temporary directory, and the from/to lang is assumed based on the .argosmodel provided.
Details on the ctranslate2 API are available here: https://github.com/OpenNMT/CTranslate2/blob/master/docs/python.md
Note that this bypasses argostranslate paragraph splitting, so this is intended for short test strings.
'''
import os
import sys
import pathlib
import argparse
import tempfile

# Temporarily enable ctranslate2 verbose logging while argostranslate is
# imported, so that backend details are printed during its import of
# ctranslate2.  try/finally guarantees the caller's environment is restored
# even if the import itself raises.
_CT2_ENVAR = 'CT2_VERBOSE'
_ct2_prev = os.environ.get(_CT2_ENVAR)
os.environ[_CT2_ENVAR] = '1'
try:
    from argostranslate import package, translate, settings
finally:
    if _ct2_prev is None:
        # pop with a default so a concurrent removal cannot raise KeyError
        os.environ.pop(_CT2_ENVAR, None)
    else:
        os.environ[_CT2_ENVAR] = _ct2_prev

import ctranslate2
import sentencepiece as spm
import stanza
class ArgosDebug:
    """Step through argostranslate's packaged-translation pipeline with output.

    Installs a single .argosmodel into a temporary directory, runs the normal
    argostranslate translation for comparison, then re-runs each stage
    (sentence splitting, sentencepiece tokenization, ctranslate2 batch
    translation, detokenization) manually, printing intermediate results.
    """

    def __init__(self, text, modelpath, translate_opts):
        """Validate the model path and derive the language pair.

        :param text: source text to translate (short strings only; paragraph
            splitting is bypassed).
        :param modelpath: full path to a ``<from>_<to>.argosmodel`` file.
        :param translate_opts: keyword options forwarded verbatim to
            ``ctranslate2.Translator.translate_batch`` (must include
            ``num_hypotheses``).
        :raises Exception: if ``modelpath`` does not exist, is not an
            .argosmodel file, or its name does not encode a language pair.
        """
        self.text = text
        self.modelpath = modelpath
        self.translate_opts = translate_opts
        # Working directories; populated in run().
        self.tmpdir = None
        self.pkgdir = None
        # Translation objects; populated in install().
        self.langs = None
        self.pkg = None
        self.translator = None
        # Validate and parse modelpath.  os.path helpers are used instead of
        # naive '/' splitting so Windows-style path separators also work.
        if not os.path.isfile(modelpath):
            raise Exception(f'modelpath={modelpath} does not exist')
        if not modelpath.endswith('.argosmodel'):
            raise Exception(f'modelpath={modelpath} is not a valid .argosmodel file')
        self.pkgname = os.path.splitext(os.path.basename(modelpath))[0]
        codes = self.pkgname.split('_')
        if len(codes) != 2:
            # Fail with a clear message instead of an opaque unpack error.
            raise Exception(f'cannot infer from/to language codes from '
                            f'package name "{self.pkgname}"')
        self.fromcode, self.tocode = codes
        print(f'translating {self.fromcode} -> {self.tocode}')

    def run(self):
        """Install the model into a temp dir, then run both translations."""
        with tempfile.TemporaryDirectory() as tmpdir:
            self.tmpdir = tmpdir
            self.pkgdir = os.path.join(tmpdir, 'packages')
            self.install()
            self.translate()
            self.translate_debug()

    def install(self):
        """Install the .argosmodel under the temp dir and load its objects.

        Redirects argostranslate's data/package directories so the host's
        installed packages are untouched, then builds the language table and
        a raw ctranslate2.Translator for the debug pass.
        """
        settings.data_dir = pathlib.Path(self.tmpdir)
        settings.package_data_dir = pathlib.Path(self.pkgdir)
        settings.package_dirs = [settings.package_data_dir]
        package.install_from_path(self.modelpath)
        pkgs = package.get_installed_packages()
        # Only one package was installed into the fresh directory.
        self.pkg = pkgs[0]
        self.langs = {lang.code: lang for lang in translate.load_installed_languages()}
        self.translator = ctranslate2.Translator(os.path.join(self.pkg.package_path, 'model'))

    def translate(self):
        """Run the normal argostranslate translation for reference output."""
        print('running native translation')
        translator = self.langs[self.fromcode].get_translation(self.langs[self.tocode])
        outstr = translator.translate(self.text)
        print(f'native translation result:\n\t"{outstr}"')

    def translate_debug(self):
        """Replicate translate.apply_packaged_translation with debug output.

        Prints the stanza sentence split, the sentencepiece pieces, and the
        raw ctranslate2 batch results for every hypothesis before
        detokenizing.
        """
        print('starting debug translation')
        pkg = self.pkg
        input_text = self.text
        translator = self.translator
        sp_model_path = str(pkg.package_path / 'sentencepiece.model')
        sp_processor = spm.SentencePieceProcessor(model_file=sp_model_path)
        stanza_pipeline = stanza.Pipeline(lang=pkg.from_code,
                                          dir=str(pkg.package_path / 'stanza'),
                                          processors='tokenize', use_gpu=False,
                                          logging_level='WARNING')
        stanza_sbd = stanza_pipeline(input_text)
        sentences = [sentence.text for sentence in stanza_sbd.sentences]
        print('sentences:')
        for sentence in sentences:
            print(f'\t"{sentence}"')
        sentence_pieces = [sp_processor.encode(sentence, out_type=str) for sentence in sentences]
        print('sentence pieces:')
        for piece in sentence_pieces:
            print(f'\t{piece}')
        translated_batches = translator.translate_batch(sentence_pieces, **self.translate_opts)
        for i in range(self.translate_opts['num_hypotheses']):
            tokens = []
            print(f'hypotheses={i}')
            for j, batch in enumerate(translated_batches):
                print(f'\tbatch={j}:')
                print(f'\t\traw: {batch[i]}')
                # NOTE(review): dict-style access assumes the older
                # ctranslate2 result format; newer releases return
                # TranslationResult objects -- confirm against the installed
                # ctranslate2 version.
                print(f'\t\ttokens: {batch[i]["tokens"]}')
                tokens += batch[i]['tokens']
            result = self.translate_tokens(tokens)
            print(f'\tdebug translation:\n\t\t"{result}"')

    @staticmethod
    def translate_tokens(tokens):
        """Join sentencepiece tokens back into plain text.

        Replaces the sentencepiece word-boundary marker (U+2581) with a
        space and strips the single leading space the tokenizer adds.
        """
        detokenized = ''.join(tokens)
        detokenized = detokenized.replace('\u2581', ' ')
        to_return = detokenized
        if len(to_return) > 0 and to_return[0] == ' ':
            # Remove space at the beginning of the translation added
            # by the tokenizer.
            to_return = to_return[1:]
        return to_return
def main(argv):
    """Parse CLI arguments and run a debug translation.

    :param argv: argument list excluding the program name
        (i.e. ``sys.argv[1:]``).
    :raises SystemExit: via argparse on missing/invalid arguments, including
        ``--num-hypotheses`` greater than ``--beam-size``.
    """
    pars = argparse.ArgumentParser(
        'debug_translation',
        description='Provide debug details on an individual translation')
    pars.add_argument('text', help='Text to translate')
    pars.add_argument('modelpath', help='Full path to *.argosmodel')
    pars.add_argument('--beam-size', default=2, type=int,
                      help='If 1 uses greedy search, if >=2 uses beam search')
    pars.add_argument('--max-batch-size', default=32, type=int,
                      help='Default 32 is used by argostranslate')
    pars.add_argument('--num-hypotheses', default=1, type=int,
                      help='Number of hypotheses to return, must be <= beam_size')
    opts = vars(pars.parse_args(argv))
    # ctranslate2 rejects num_hypotheses > beam_size deep inside
    # translate_batch; fail fast here with a clear CLI error instead.
    if opts['num_hypotheses'] > opts['beam_size']:
        pars.error('--num-hypotheses must be <= --beam-size')
    text = opts.pop('text')
    modelpath = opts.pop('modelpath')
    # Every remaining option is forwarded verbatim to translate_batch.
    translate_opts = opts
    argosdebug = ArgosDebug(text, modelpath, translate_opts)
    argosdebug.run()
# Script entry point: forward CLI args (minus the program name) to main()
# and propagate its return value as the process exit status.
if __name__ == '__main__':
    sys.exit(main(sys.argv[1:]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment