import sys
import os
import argparse
import json
import re

# Silence TensorFlow's C++ log output; this env var must be set before
# TensorFlow is imported to take effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow.compat.v1 as tf
import numpy as np
from train.modeling import GroverModel, GroverConfig, sample
from tokenizers import ByteLevelBPETokenizer

tf.logging.set_verbosity(tf.logging.DEBUG)

# Ignore TF deprecation warnings temporarily.
from tensorflow.python.util import deprecation
deprecation._PRINT_DEPRECATION_WARNINGS = False
try:
    from tensorflow.python.util import module_wrapper as deprecation
except ImportError:
    from tensorflow.python.util import deprecation_wrapper as deprecation
deprecation._PER_MODULE_WARNING_LIMIT = 0
#####
parser = argparse.ArgumentParser(description='Contextual generation (aka, given some metadata, we will generate articles)')
parser.add_argument(
    '-metadata_fn',
    dest='metadata_fn',
    type=str,
    help='Path to a JSONL containing metadata',
)
parser.add_argument(
    '-out_fn',
    dest='out_fn',
    type=str,
    help='Output JSONL, which will contain the completed jsons',
)
parser.add_argument(
    '-input',
    dest='input',
    type=str,
    help='Text to complete',
)
parser.add_argument(
    '-config_fn',
    dest='config_fn',
    default='../configs/base.json',
    type=str,
    help='Configuration JSON for the model',
)
parser.add_argument(
    '-ckpt_fn',
    dest='ckpt_fn',
    default='../models/mega/model.ckpt',
    type=str,
    help='Checkpoint file for the model',
)
parser.add_argument(
    '-target',
    dest='target',
    default='article',
    type=str,
    help='What to generate for each item in metadata_fn. Can be article (body), title, etc.',
)
parser.add_argument(
    '-batch_size',
    dest='batch_size',
    default=1,
    type=int,
    help='How many things to generate per context. Will split into chunks if need be.',
)
parser.add_argument(
    '-num_folds',
    dest='num_folds',
    default=1,
    type=int,
    help='Number of folds. Useful if we want to split up a big file into multiple jobs.',
)
parser.add_argument(
    '-fold',
    dest='fold',
    default=0,
    type=int,
    help='Which fold we are on. Useful if we want to split up a big file into multiple jobs.',
)
parser.add_argument(
    '-max_batch_size',
    dest='max_batch_size',
    default=None,
    type=int,
    help='Max batch size. You can leave this out and we will infer one based on the number of hidden layers.',
)
parser.add_argument(
    '-top_p',
    dest='top_p',
    default=0.95,
    type=float,
    help='p to use for top-p sampling. If this isn\'t None, use this for everything.',
)
parser.add_argument(
    '-min_len',
    dest='min_len',
    default=1024,
    type=int,
    help='Minimum length of a sample',
)
parser.add_argument(
    '-eos_token',
    dest='eos_token',
    default=0,
    type=int,
    help='EOS token id',
)
parser.add_argument(
    '-samples',
    dest='samples',
    default=5,
    type=int,
    help='Number of samples',
)
parser.add_argument(
    '-vocab_file',
    dest='vocab_file',
    default='../vocabs/spanish/vocab.json',
    type=str,
    help='Tokenizer vocab.json file',
)
parser.add_argument(
    '-merges_file',
    dest='merges_file',
    default='../vocabs/spanish/merges.txt',
    type=str,
    help='Tokenizer merges.txt file',
)
def extract_generated_target(output_tokens, tokenizer):
    """
    Given some tokens that were generated, extract the target.
    :param output_tokens: [num_tokens] array of token ids that was generated
    :param tokenizer: the tokenizer they were encoded with
    :return: dict with the decoded text and the [start_ind, end_ind) token span it covers
    """
    # No start-token filtering is done here; the whole sequence is decoded.
    assert output_tokens.ndim == 1
    start_ind = 0
    end_ind = output_tokens.shape[0]
    return {
        'extraction': tokenizer.decode(output_tokens),
        'start_ind': start_ind,
        'end_ind': end_ind,
    }
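# Illustrative usage (hypothetical token ids; the function is not called
# anywhere in this script):
#   extract_generated_target(np.array([12, 345, 6789]), tokenizer)
#   # -> {'extraction': '<decoded text>', 'start_ind': 0, 'end_ind': 3}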
args = parser.parse_args()
proj_root_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
tokenizer = ByteLevelBPETokenizer(args.vocab_file, args.merges_file)
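# Byte-level BPE round trip (illustrative; "hola mundo" is an arbitrary string
# and exact round-tripping depends on the tokenizer's settings):
#   ids = tokenizer.encode("hola mundo").ids  # list of int token ids
#   tokenizer.decode(ids)                     # ~ "hola mundo"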
news_config = GroverConfig.from_json_file(args.config_fn)
# We might have to split the batch into multiple chunks if the batch size is too large
default_mbs = {12: 32, 24: 16, 48: 3}
max_batch_size = args.max_batch_size if args.max_batch_size is not None else default_mbs[news_config.num_hidden_layers]
# Factorize args.batch_size = (num_chunks * batch_size_per_chunk) s.t. batch_size_per_chunk <= max_batch_size
num_chunks = int(np.ceil(args.batch_size / max_batch_size))
batch_size_per_chunk = int(np.ceil(args.batch_size / num_chunks))
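# Worked example (illustrative numbers): with args.batch_size=50 and
# max_batch_size=16, num_chunks = ceil(50/16) = 4 and
# batch_size_per_chunk = ceil(50/4) = 13, so 13 <= 16 holds.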
# This controls the top p for each generation.
top_p = np.ones((num_chunks, batch_size_per_chunk), dtype=np.float32) * args.top_p
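# Continuing the example above, top_p would be a (4, 13) array filled with
# args.top_p, i.e. one top-p value per sample in each chunk.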
tf_config = tf.ConfigProto(allow_soft_placement=True)
# Keras' categorical_accuracy is used at the end of this script to compare the
# sampled tokens against the model's output probabilities.
from keras.metrics import categorical_accuracy as accuracy
with tf.Session(config=tf_config, graph=tf.Graph()) as sess:
    initial_context = tf.placeholder(tf.int32, [batch_size_per_chunk, None])
    p_for_topp = tf.placeholder(tf.float32, [batch_size_per_chunk])
    eos_token = tf.placeholder(tf.int32, [])
    min_len = tf.placeholder(tf.int32, [])
    tokens, probs = sample(news_config=news_config, initial_context=initial_context,
                           eos_token=eos_token, min_len=min_len, ignore_ids=None,
                           p_for_topp=p_for_topp, do_topk=False)

    saver = tf.train.Saver()
    saver.restore(sess, args.ckpt_fn)

    # NOTE: a hard-coded Spanish test sentence is used here instead of args.input.
    # It translates to: "hi, I am a really long text used to test the prediction
    # quality of GPT-2 in Spanish."
    text = "hola, soy un texto re largo utilizado para testear la calidad de prediccion de GPT-2 en español."
    line = text
    bert_tokens = tokenizer.encode(line)
    encoded = bert_tokens.ids
    context_formatted = []
    context_formatted.extend(encoded)
    # Format context end

    # Leftovers from the original generation loop; unused below.
    gens = []
    gens_raw = []
    gen_probs = []

    chunk_i = 0
    acc_value = accuracy(tokens, probs)
    score = acc_value.eval(feed_dict={initial_context: [context_formatted] * batch_size_per_chunk,
                                      eos_token: args.eos_token, min_len: args.min_len,
                                      p_for_topp: top_p[chunk_i]})
    print(score)
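
# Example invocation (illustrative; the script name is a placeholder and the
# paths assume the repo layout implied by the defaults above):
#   python generate_score.py \
#       -config_fn ../configs/base.json \
#       -ckpt_fn ../models/mega/model.ckpt \
#       -vocab_file ../vocabs/spanish/vocab.json \
#       -merges_file ../vocabs/spanish/merges.txt \
#       -top_p 0.95 -min_len 1024 -batch_size 1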