import os

import tensorflow as tf  # TF 1.x is required; tf.contrib was removed in TF 2.x

import run_classifier  # from the original BERT repository (google-research/bert)

# run_classifier already imports the BERT tokenization module
tokenization = run_classifier.tokenization
base_path = "C:/Users/wfng/path/to/optimized_model"  # modify accordingly
init_checkpoint = os.path.join(base_path, 'model.ckpt')
bert_config_file = os.path.join(base_path, 'bert_config.json')
vocab_file = os.path.join(base_path, 'vocab.txt')
processor = run_classifier.ColaProcessor()
label_list = processor.get_labels()  # ColaProcessor's defaults are ["0", "1"]
# The original BERT source code combines train, eval and predict into a single
# configuration, so these values must be supplied at initialization even though
# they only matter for the run configuration, not for prediction itself
BATCH_SIZE = 8
SAVE_SUMMARY_STEPS = 100
SAVE_CHECKPOINTS_STEPS = 500
OUTPUT_DIR = "./bert_output/output"
# Variables that need to be modified
labels = ["0", "1", "2"]  # modify based on the labels that you have
label_list = labels  # override the ColaProcessor defaults so num_labels matches the checkpoint
MAX_SEQ_LENGTH = 64  # modify based on the sequence length used during fine-tuning
is_lower_case = True  # modify based on whether the model is uncased or cased

# Variables for configuration
tokenization.validate_case_matches_checkpoint(is_lower_case, init_checkpoint)
bert_config = run_classifier.modeling.BertConfig.from_json_file(bert_config_file)
tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=is_lower_case)
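# Optional sanity check (a minimal sketch, not part of the original snippet):
# FullTokenizer splits raw text into WordPiece tokens from vocab.txt; the exact
# pieces depend on the vocabulary shipped with your checkpoint
print(tokenizer.tokenize("The quick brown fox jumps over the lazy dog."))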
is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
run_config = tf.contrib.tpu.RunConfig(
    model_dir=OUTPUT_DIR,
    cluster=None,
    master=None,
    save_checkpoints_steps=SAVE_CHECKPOINTS_STEPS,
    tpu_config=tf.contrib.tpu.TPUConfig(
        iterations_per_loop=1000,
        num_shards=8,
        per_host_input_for_training=is_per_host))
# Model function that builds the BERT classifier and restores the fine-tuned
# checkpoint; num_train_steps and num_warmup_steps can be None since we only
# run prediction, not training
model_fn = run_classifier.model_fn_builder(
    bert_config=bert_config,
    num_labels=len(label_list),
    init_checkpoint=init_checkpoint,
    learning_rate=5e-5,
    num_train_steps=None,
    num_warmup_steps=None,
    use_tpu=False,
    use_one_hot_embeddings=False)
# Estimator; TPUEstimator falls back to CPU/GPU when use_tpu is False
estimator = tf.contrib.tpu.TPUEstimator(
    use_tpu=False,
    model_fn=model_fn,
    config=run_config,
    train_batch_size=BATCH_SIZE,
    eval_batch_size=BATCH_SIZE,
    predict_batch_size=BATCH_SIZE)
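# A minimal prediction sketch (not part of the original snippet), assuming the
# helpers from the BERT repo's run_classifier: InputExample,
# convert_examples_to_features and input_fn_builder. The sample sentence is a
# placeholder; replace it with your own inputs.
pred_sentences = ["This is a sample sentence to classify."]
input_examples = [
    run_classifier.InputExample(guid="", text_a=s, text_b=None, label=labels[0])
    for s in pred_sentences  # label here is a dummy value required by the API
]
input_features = run_classifier.convert_examples_to_features(
    input_examples, label_list, MAX_SEQ_LENGTH, tokenizer)
predict_input_fn = run_classifier.input_fn_builder(
    features=input_features,
    seq_length=MAX_SEQ_LENGTH,
    is_training=False,
    drop_remainder=False)
# model_fn_builder's PREDICT mode outputs a dict with a "probabilities" key
for sentence, prediction in zip(pred_sentences, estimator.predict(predict_input_fn)):
    probabilities = prediction["probabilities"]
    print(sentence, "->", labels[probabilities.argmax()])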