Created
October 31, 2020 01:32
-
-
Save fancyerii/feb49e0191f5fc36802f424a91f9d475 to your computer and use it in GitHub Desktop.
predict.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Inference script: translate English sentences to Chinese using a
# tensor2tensor Transformer trained on the translate_enzh_wmt8k problem.
# (Fix: stripped the " | |" extraction artifacts that made the file
# syntactically invalid, and grouped imports stdlib / third-party.)
import os

import numpy as np  # noqa: F401 -- used by decode() below
import tensorflow as tf

from tensor2tensor import models  # noqa: F401 -- import registers the models
from tensor2tensor import problems
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model  # noqa: F401
from tensor2tensor.utils import trainer_lib  # noqa: F401
from tensor2tensor.utils.trainer_lib import create_hparams

PROBLEM = 'translate_enzh_wmt8k'  # registry.list_models() shows all registered models
MODEL = 'transformer'
# Start with "transformer_base", or 'transformer_base_single_gpu'
# if training on a single GPU.
HPARAMS = 'transformer_base_single_gpu'

DATA_DIR = os.path.expanduser("~/t2tcn2/data")
TRAIN_DIR = os.path.expanduser("~/t2tcn2/train")

t2t_problem = problems.problem(PROBLEM)
Modes = tf.estimator.ModeKeys

hparams = create_hparams(HPARAMS, data_dir=DATA_DIR, problem_name=PROBLEM)
translate_model = registry.model(MODEL)(hparams, Modes.PREDICT)

# Feature encoders (fixed pre-processing, subword vocab) from the problem.
encoders = t2t_problem.feature_encoders(DATA_DIR)
def encode(input_str, output_str=None):
    """Turn a source string into the features dict expected by `infer`.

    The text is subword-encoded, terminated with the EOS id (1), and
    reshaped to the 3-D [batch=1, length, 1] layout the model expects.
    `output_str` is unused at inference time.
    """
    token_ids = encoders["inputs"].encode(input_str)
    token_ids.append(1)  # EOS marker
    batched = tf.reshape(token_ids, [1, -1, 1])  # make it 3-D
    return {"inputs": batched}
def decode(integers):
    """Convert a (possibly nested) sequence of subword ids back to a string.

    The sequence is flattened to a plain 1-D id list and truncated at the
    first EOS id (1); everything after EOS is padding.
    """
    integers = list(np.squeeze(integers))
    if 1 in integers:
        integers = integers[:integers.index(1)]  # drop EOS and trailing padding
    # Pass the already-flattened Python list straight to the target encoder.
    # The original re-applied np.squeeze here, which was redundant and, for a
    # single-id output, produced a non-iterable 0-d array.
    return encoders["targets"].decode(integers)
# Most recent checkpoint written by the training run.
ckpt_path = tf.train.latest_checkpoint(TRAIN_DIR)
print('Latest Checkpoint: ', ckpt_path)


def translate(inputs):
    """Translate one English sentence, restoring weights from ckpt_path."""
    features = encode(inputs)
    # Variables created inside infer() are initialized from the checkpoint.
    with tf.compat.v1.restore_variables_on_create(ckpt_path):
        outputs = translate_model.infer(features)["outputs"]
    return decode(outputs)
# Demo sentences pushed through the model one at a time.
inputs = [
    "I think they will never come back to the US.",
    "Human rights is the first priority.",
    "Everyone should have health insurance.",
    "President Trump's overall approval rating dropped 7% over the past month",
]
for sentence in inputs:
    translation = translate(sentence)
    # ANSI colour codes: blue label / black text, then magenta label.
    print("\33[34m Inputs:\33[30m %s" % sentence)
    print("\033[35m Outputs:\33[30m %s" % translation)
    print()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment