Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save ayushidalmia/d74cf0e7545b860d337914a5c05f7177 to your computer and use it in GitHub Desktop.
Save ayushidalmia/d74cf0e7545b860d337914a5c05f7177 to your computer and use it in GitHub Desktop.
from onmt.translate.Translator import Translator
from onmt.translate.Translation import TranslationBuilder
import onmt
import onmt.ModelConstructor
from utils import *
from parameters import *
import opts
import argparse
import torch
from onmt.io import IO
class MyTranslator:
    """Load an OpenNMT-py checkpoint and wrap it in a beam-search ``Translator``.

    Attributes set by ``__init__``:
        opt: Namespace parsed with the options registered by
            ``opts.translate_opts``; ``model``, ``n_best`` and ``max_length``
            are overridden from the constructor arguments and ``cuda`` is
            derived from ``gpu``.
        generator: The ``Translator`` built around the loaded model.
        fields: The fields object returned by ``load_test_model``.
    """

    def __init__(self, model, n_best=5, max_length=50):
        """Parse translate options, load the checkpoint and build the translator.

        Args:
            model: Value stored in ``self.opt.model`` before the checkpoint is
                loaded (the ``__main__`` usage passes a ``.pt`` file name).
            n_best: Value stored in ``self.opt.n_best`` and forwarded to
                ``Translator``.
            max_length: Value stored in ``self.opt.max_length`` and forwarded
                to ``Translator``.
        """
        parser = argparse.ArgumentParser(description='translate.py')
        opts.translate_opts(parser)
        # NOTE: parse_args() reads the process command line (sys.argv), so any
        # other translate option (e.g. the GPU id or beam size) comes from there.
        self.opt = parser.parse_args()
        self.opt.model = model
        self.opt.n_best = n_best
        self.opt.max_length = max_length

        # Model options are parsed from an empty argument list, i.e. only the
        # defaults registered by opts.model_opts are handed to the loader.
        train_parser = argparse.ArgumentParser(description='train.py')
        opts.model_opts(train_parser)
        train_opt = train_parser.parse_known_args([])[0]

        self.opt.cuda = self.opt.gpu > -1
        if self.opt.cuda:
            # Use the parsed option value; the device id lives on the parsed
            # namespace, not on the option-definition module.
            torch.cuda.set_device(self.opt.gpu)

        # The third returned value (the checkpoint's model options) is unused.
        fields, loaded_model, _model_opt = onmt.ModelConstructor.load_test_model(
            self.opt, train_opt.__dict__)

        # Forward the computed cuda flag so the translator agrees with the
        # device selected above (it stays False when gpu == -1).
        self.generator = Translator(
            loaded_model,
            fields,
            beam_size=self.opt.beam_size,
            n_best=self.opt.n_best,
            max_length=self.opt.max_length,
            global_scorer=None,
            copy_attn=False,
            cuda=self.opt.cuda,
            beam_trace=False,
            min_length=0,
        )
        self.fields = fields

    def get_translation(self, text):
        """Placeholder for translating ``text``; not implemented yet.

        Currently performs no work and returns ``None``.
        """
        # TODO: implement translation using self.generator / self.fields.
        return None
if __name__ == "__main__":
    # Script entry point: build a translator from the demo checkpoint file.
    tn = MyTranslator(
        model="demo-model_acc_64.75_ppl_4.74_e14.pt",
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment