Skip to content

Instantly share code, notes, and snippets.

@fancyerii
Created October 31, 2020 01:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save fancyerii/feb49e0191f5fc36802f424a91f9d475 to your computer and use it in GitHub Desktop.
predict.py
import tensorflow as tf
from tensor2tensor import models
from tensor2tensor import problems
from tensor2tensor.utils import trainer_lib
from tensor2tensor.utils import t2t_model
from tensor2tensor.utils import registry
from tensor2tensor.utils.trainer_lib import create_hparams
import os
import numpy as np
# Registered tensor2tensor problem name: English->Chinese WMT, 8k subword vocab.
# (Use registry.list_models() / problems.available() to list what is registered.)
PROBLEM = 'translate_enzh_wmt8k' # registry.list_models() # Show all registered models
# Registered model architecture to instantiate.
MODEL = 'transformer' # Hyperparameters for the model by default
# start with "transformer_base" or 'transformer_base_single_gpu'
# if training on a single GPU
HPARAMS = 'transformer_base_single_gpu'
# DATA_DIR must hold the generated problem data (vocab files); TRAIN_DIR the
# checkpoints written by training. Both must match the training run's layout.
DATA_DIR = os.path.expanduser("~/t2tcn2/data")
TRAIN_DIR = os.path.expanduser("~/t2tcn2/train")
t2t_problem = problems.problem(PROBLEM)
# Alias for the estimator mode enum; the model is built in PREDICT mode below.
Modes = tf.estimator.ModeKeys
hparams = create_hparams(HPARAMS, data_dir=DATA_DIR, problem_name=PROBLEM)
# Instantiate the registered model class with these hparams for inference.
translate_model = registry.model(MODEL)(hparams, Modes.PREDICT)
# Get the encoders (fixed pre-processing) from the problem
encoders = t2t_problem.feature_encoders(DATA_DIR)
def encode(input_str, output_str=None):
    """Turn a source-language string into a features dict ready for inference.

    ``output_str`` is accepted for API symmetry with t2t examples but unused.
    """
    # Subword-encode the input text and append the EOS token (id 1).
    token_ids = encoders["inputs"].encode(input_str) + [1]
    # The model expects a rank-3 [batch, length, 1] tensor.
    return {"inputs": tf.reshape(token_ids, [1, -1, 1])}
def decode(integers):
    """Convert model output token ids (tensor/array/list) back into a string.

    Strips batch/singleton dimensions and truncates at the first EOS (id 1).
    """
    # Flatten away the [1, length, 1] batch dims into a plain list of ids.
    integers = list(np.squeeze(integers))
    if 1 in integers:
        # Keep only the tokens before EOS.
        integers = integers[:integers.index(1)]
    # Fix: the original called np.squeeze again on this already-flat list;
    # for a single-token result that yields a 0-d array, which the target
    # encoder cannot iterate. Pass the list of ids directly instead.
    return encoders["targets"].decode(integers)
# Get the latest checkpoint
# Resolves the newest checkpoint prefix under TRAIN_DIR (None if none exist,
# in which case translate() below would restore nothing — ensure training ran).
ckpt_path = tf.train.latest_checkpoint(TRAIN_DIR)
print('Latest Checkpoint: ', ckpt_path)
def translate(inputs):
    """Translate one source sentence and return the decoded target string."""
    features = encode(inputs)
    # Restore weights from the latest checkpoint as variables get created
    # during the first inference call.
    with tf.compat.v1.restore_variables_on_create(ckpt_path):
        output_ids = translate_model.infer(features)["outputs"]
    return decode(output_ids)
# A few sample sentences to translate and print (source in blue/magenta,
# text reset to black via the ANSI escape codes).
inputs = [
    "I think they will never come back to the US.",
    "Human rights is the first priority.",
    "Everyone should have health insurance.",
    "President Trump's overall approval rating dropped 7% over the past month",
]

for source in inputs:
    translated = translate(source)
    print("\33[34m Inputs:\33[30m %s" % source)
    print("\033[35m Outputs:\33[30m %s" % translated)
    print()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment