Last active
December 17, 2020 20:21
-
-
Save mikemoritz/a5bf76193ccb16d018a1af9ec584fb41 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
''' | |
This script pulls apart the logic in argostranslate.translate.apply_packaged_translation and prints details. | |
The model is installed into a temporary directory, and the from/to lang is assumed based on the .argosmodel provided. | |
Details on the ctranslate2 API is available here: https://github.com/OpenNMT/CTranslate2/blob/master/docs/python.md | |
Note that this bypasses argostranslate paragraph splitting, so this is intended for short test strings. | |
''' | |
import os
import sys
import pathlib
import argparse
import tempfile
# wrap import with ctranslate2 verbose logging to get backend details
# NOTE(review): CT2_VERBOSE is set BEFORE importing argostranslate because that
# import presumably pulls in ctranslate2, which reads the env var at load time
# — confirm against the installed argostranslate version.
envar = 'CT2_VERBOSE'
# remember the caller's original value (None if unset) so we can restore it
ctv_old = os.environ.get(envar)
os.environ[envar] = "1"
from argostranslate import package, translate, settings
# restore the environment exactly as it was before the wrapped import
if ctv_old is None:
    os.environ.pop(envar)
else:
    os.environ[envar] = ctv_old
import ctranslate2
import sentencepiece as spm
import stanza
class ArgosDebug:
    """Pull apart argostranslate.translate.apply_packaged_translation step by step.

    Installs the given .argosmodel into a temporary directory, runs the normal
    argostranslate translation, then re-runs each internal stage (sentence
    splitting, sentencepiece tokenization, ctranslate2 batch translation,
    detokenization) with verbose printing.
    """

    def __init__(self, text, modelpath, translate_opts):
        """Validate the model path and derive from/to language codes.

        :param text: short input string to translate (paragraph splitting is
            bypassed, so this is intended for short test strings)
        :param modelpath: full path to a ``*.argosmodel`` file, named
            ``<from>_<to>.argosmodel``
        :param translate_opts: keyword options forwarded verbatim to
            ``ctranslate2.Translator.translate_batch`` (must include
            ``num_hypotheses``)
        :raises FileNotFoundError: if *modelpath* does not exist
        :raises ValueError: if *modelpath* lacks the ``.argosmodel`` suffix or
            its basename is not of the form ``<from>_<to>``
        """
        self.text = text
        self.modelpath = modelpath
        self.translate_opts = translate_opts
        # directories — populated in run()
        self.tmpdir = None
        self.pkgdir = None
        # translation objects — populated in install()
        self.langs = None
        self.pkg = None
        self.translator = None
        # validate and parse modelpath
        if not os.path.isfile(modelpath):
            raise FileNotFoundError(f'modelpath={modelpath} does not exist')
        if not modelpath.endswith('.argosmodel'):
            raise ValueError(f'modelpath={modelpath} is not a valid .argosmodel file')
        # basename() handles both '/' and os-native separators; the original
        # split('/') broke on Windows paths. Keep split('.')[0] (not splitext)
        # so names like "en_es.v1.argosmodel" still yield "en_es".
        self.pkgname = os.path.basename(modelpath).split('.')[0]
        self.fromcode, self.tocode = self.pkgname.split('_')
        print(f'translating {self.fromcode} -> {self.tocode}')

    def run(self):
        """Install the model into a temp dir, then run both translations."""
        with tempfile.TemporaryDirectory() as tmpdir:
            self.tmpdir = tmpdir
            self.pkgdir = os.path.join(tmpdir, 'packages')
            self.install()
            self.translate()
            self.translate_debug()

    def install(self):
        """Point argostranslate at the temp dir and install the model there."""
        # redirect argostranslate's global settings so nothing touches the
        # user's real package directory
        settings.data_dir = pathlib.Path(self.tmpdir)
        settings.package_data_dir = pathlib.Path(self.pkgdir)
        settings.package_dirs = [settings.package_data_dir]
        package.install_from_path(self.modelpath)
        pkgs = package.get_installed_packages()
        # the temp dir is fresh, so the only installed package is ours
        self.pkg = pkgs[0]
        self.langs = {lang.code: lang for lang in translate.load_installed_languages()}
        # raw ctranslate2 translator used by translate_debug()
        self.translator = ctranslate2.Translator(os.path.join(self.pkg.package_path, 'model'))

    def translate(self):
        """Run the normal argostranslate translation and print the result."""
        print('running native translation')
        translator = self.langs[self.fromcode].get_translation(self.langs[self.tocode])
        outstr = translator.translate(self.text)
        print(f'native translation result:\n\t"{outstr}"')

    def translate_debug(self):
        """Copy of translate.apply_packaged_translation with additional output.

        Prints every intermediate artifact: stanza sentence splits,
        sentencepiece token pieces, raw ctranslate2 batch results, and the
        detokenized text per hypothesis.
        """
        print('starting debug translation')
        pkg = self.pkg
        input_text = self.text
        translator = self.translator
        sp_model_path = str(pkg.package_path / 'sentencepiece.model')
        sp_processor = spm.SentencePieceProcessor(model_file=sp_model_path)
        # sentence boundary detection only — no GPU needed for this step
        stanza_pipeline = stanza.Pipeline(lang=pkg.from_code,
                                          dir=str(pkg.package_path / 'stanza'),
                                          processors='tokenize', use_gpu=False,
                                          logging_level='WARNING')
        stanza_sbd = stanza_pipeline(input_text)
        sentences = [sentence.text for sentence in stanza_sbd.sentences]
        print('sentences:')
        for sentence in sentences:
            print(f'\t"{sentence}"')
        sentence_pieces = [sp_processor.encode(sentence, out_type=str) for sentence in sentences]
        print('sentence pieces:')
        for piece in sentence_pieces:
            print(f'\t{piece}')
        translated_batches = translator.translate_batch(sentence_pieces, **self.translate_opts)
        # one pass per requested hypothesis; tokens from all sentence batches
        # for a hypothesis are concatenated before detokenizing
        for i in range(self.translate_opts['num_hypotheses']):
            tokens = []
            print(f'hypotheses={i}')
            for j, batch in enumerate(translated_batches):
                print(f'\tbatch={j}:')
                print(f'\t\traw: {batch[i]}')
                print(f'\t\ttokens: {batch[i]["tokens"]}')
                tokens += batch[i]['tokens']
            result = self.translate_tokens(tokens)
            print(f'\tdebug translation:\n\t\t"{result}"')

    @staticmethod
    def translate_tokens(tokens):
        """Detokenize sentencepiece tokens back into plain text.

        Joins the pieces, maps the '▁' word-boundary marker back to a space,
        and strips the single leading space the tokenizer adds.
        """
        detokenized = ''.join(tokens)
        detokenized = detokenized.replace('▁', ' ')
        to_return = detokenized
        if len(to_return) > 0 and to_return[0] == ' ':
            # Remove space at the beginning of the translation added
            # by the tokenizer.
            to_return = to_return[1:]
        return to_return
def main(argv):
    """Parse command-line arguments and run the debug translation.

    :param argv: argument list without the program name (``sys.argv[1:]``)
    :returns: 0 on success (used as the process exit code)
    """
    pars = argparse.ArgumentParser('debug_translation', description='Provide debug details on an individual translation')
    pars.add_argument('text', help='Text to translate')
    pars.add_argument('modelpath', help='Full path to *.argosmodel')
    pars.add_argument('--beam-size', default=2, type=int, help='If 1 uses greedy search, if >=2 uses beam search')
    pars.add_argument('--max-batch-size', default=32, type=int, help='Default 32 is used by argostranslate')
    pars.add_argument('--num-hypotheses', default=1, type=int, help='Number of hypotheses to return, must be <= beam_size')
    opts = pars.parse_args(argv)
    # enforce the constraint stated in the --num-hypotheses help text up front,
    # instead of letting ctranslate2 fail later with a less obvious error
    if opts.num_hypotheses > opts.beam_size:
        pars.error('--num-hypotheses must be <= --beam-size')
    opts = vars(opts)
    text = opts.pop('text')
    modelpath = opts.pop('modelpath')
    # whatever remains is passed straight through to translate_batch
    translate_opts = opts
    argosdebug = ArgosDebug(text, modelpath, translate_opts)
    argosdebug.run()
    return 0


if __name__ == '__main__':
    sys.exit(main(sys.argv[1:]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment