# run_classifier and tokenization come from the original BERT repository
# (google-research/bert) and are assumed to be on the Python path.
import os
import tensorflow as tf
import run_classifier

tokenization = run_classifier.tokenization

base_path = "C:/Users/wfng/path/to/optimized_model"  # modify accordingly
init_checkpoint = os.path.join(base_path, 'model.ckpt')
bert_config_file = os.path.join(base_path, 'bert_config.json')
vocab_file = os.path.join(base_path, 'vocab.txt')

processor = run_classifier.ColaProcessor()
label_list = processor.get_labels()

# Since the original BERT source code combines train, eval and predict in a
# single configuration, these values must be supplied at initialization; they
# can be arbitrary here, as they are only needed to build the run configuration.
BATCH_SIZE = 8
SAVE_SUMMARY_STEPS = 100
SAVE_CHECKPOINTS_STEPS = 500
OUTPUT_DIR = "./bert_output/output"

# Variables that need to be modified
labels = ["0", "1", "2"]  # modify to match the labels the model was fine-tuned with (should agree with processor.get_labels())
MAX_SEQ_LENGTH = 64  # modify to match the sequence length used during fine-tuning
is_lower_case = True  # True for an uncased model, False for a cased one

# Build the configuration objects
tokenization.validate_case_matches_checkpoint(is_lower_case, init_checkpoint)
bert_config = run_classifier.modeling.BertConfig.from_json_file(bert_config_file)
tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=is_lower_case)
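# Optional sanity check (illustrative, not part of the original snippet):
# FullTokenizer applies WordPiece tokenization, so out-of-vocabulary words are
# split into "##"-prefixed sub-word units; the exact pieces depend on vocab.txt.
# e.g. tokenizer.tokenize("tokenization") -> ['token', '##ization']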
is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
# The TPU-specific settings below are ignored when use_tpu=False;
# TPUEstimator then falls back to running on CPU/GPU.
run_config = tf.contrib.tpu.RunConfig(
    model_dir=OUTPUT_DIR,
    cluster=None,
    master=None,
    save_checkpoints_steps=SAVE_CHECKPOINTS_STEPS,
    tpu_config=tf.contrib.tpu.TPUConfig(
        iterations_per_loop=1000,
        num_shards=8,
        per_host_input_for_training=is_per_host))
# Model function; num_train_steps and num_warmup_steps can stay None
# since the optimizer is only built in training mode.
model_fn = run_classifier.model_fn_builder(
    bert_config=bert_config,
    num_labels=len(label_list),
    init_checkpoint=init_checkpoint,
    learning_rate=5e-5,
    num_train_steps=None,
    num_warmup_steps=None,
    use_tpu=False,
    use_one_hot_embeddings=False)
# Estimator; with use_tpu=False this runs on CPU or GPU.
estimator = tf.contrib.tpu.TPUEstimator(
    use_tpu=False,
    model_fn=model_fn,
    config=run_config,
    train_batch_size=BATCH_SIZE,
    eval_batch_size=BATCH_SIZE,
    predict_batch_size=BATCH_SIZE)
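# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original snippet): with the
# estimator built, prediction wraps raw sentences in InputExample objects,
# converts them to fixed-length features, and feeds them through an input_fn.
# The "sentences" list below is a placeholder; InputExample,
# convert_examples_to_features and input_fn_builder all come from
# run_classifier in google-research/bert.
sentences = ["This is a test sentence.", "Another example input."]

# text_b stays None for single-sentence classification; the label is a dummy
# value taken from label_list, since labels are ignored at predict time.
input_examples = [
    run_classifier.InputExample(guid=None, text_a=s, text_b=None,
                                label=label_list[0])
    for s in sentences
]

# Convert to features with the same tokenizer and MAX_SEQ_LENGTH as above.
input_features = run_classifier.convert_examples_to_features(
    input_examples, label_list, MAX_SEQ_LENGTH, tokenizer)

# Build a prediction input_fn and run it through the estimator; each result
# carries class probabilities, and argmax picks the predicted label.
predict_input_fn = run_classifier.input_fn_builder(
    features=input_features,
    seq_length=MAX_SEQ_LENGTH,
    is_training=False,
    drop_remainder=False)

for prediction in estimator.predict(input_fn=predict_input_fn):
    probabilities = prediction["probabilities"]
    print(label_list[int(probabilities.argmax())])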