Created
October 31, 2020 01:30
-
-
Save fancyerii/9c9ea297a71cb8b36519705e67638a3e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import sys | |
import numpy as np | |
import collections | |
import matplotlib.pyplot as plt# Colab-only TensorFlow version selector | |
import tensorflow as tf | |
from tensor2tensor import models | |
from tensor2tensor import problems | |
from tensor2tensor.layers import common_layers | |
from tensor2tensor.utils import trainer_lib | |
from tensor2tensor.utils import t2t_model | |
from tensor2tensor.utils import registry | |
from tensor2tensor.utils import metrics | |
from tensor2tensor.utils.trainer_lib import (create_hparams, | |
create_run_config, | |
create_experiment) | |
# Enable TF Eager execution
# Other setup
Modes = tf.estimator.ModeKeys  # shorthand for the estimator mode keys (TRAIN/EVAL/PREDICT)

# Working directories, all under ~/t2tcn2; they are created further below.
DATA_DIR = os.path.expanduser("~/t2tcn2/data")  # This folder contains the generated training data
TMP_DIR = os.path.expanduser("~/t2tcn2/tmp")  # Scratch space for raw downloads
TRAIN_DIR = os.path.expanduser("~/t2tcn2/train")  # This folder contains the model checkpoints
EXPORT_DIR = os.path.expanduser("~/t2tcn2/export")  # This folder contains the exported model for production
TRANSLATIONS_DIR = os.path.expanduser("~/t2tcn2/translation")  # This folder contains all translated sequences
EVENT_DIR = os.path.expanduser("~/t2tcn2/event")  # Event logs, e.g. for testing the BLEU score
USR_DIR = os.path.expanduser("~/t2tcn2/user")  # This folder contains our own data that we want to add
import pathlib

# Make sure every working directory exists (equivalent to `mkdir -p` for each).
_required_dirs = (DATA_DIR, TMP_DIR, TRAIN_DIR, EXPORT_DIR,
                  TRANSLATIONS_DIR, EVENT_DIR, USR_DIR)
for _dir in _required_dirs:
    pathlib.Path(_dir).mkdir(parents=True, exist_ok=True)
# problems.available()  # Uncomment to show all registered problems.
# An English-Chinese translation dataset with an 8192-subword vocabulary.
PROBLEM = 'translate_enzh_wmt8k'
# registry.list_models()  # Uncomment to show all registered models.
MODEL = 'transformer'
# Default hyperparameter set for the model:
# start with "transformer_base", or 'transformer_base_single_gpu'
# if training on a single GPU.
HPARAMS = 'transformer_base_single_gpu'

# Look up the problem and generate its data: raw files are downloaded
# into TMP_DIR, preprocessed examples are written into DATA_DIR.
t2t_problem = problems.problem(PROBLEM)
t2t_problem.generate_data(DATA_DIR, TMP_DIR)
train_steps = 20000  # Total number of train steps across all epochs
eval_steps = 100  # Number of steps to perform for each evaluation
# NOTE(review): for t2t Transformer hparam sets, batch_size is typically
# measured in tokens, not sentences — confirm before tuning.
batch_size = 1000
save_checkpoints_steps = 1000  # Save checkpoints every 1000 steps
ALPHA = 0.1  # Learning rate
schedule = "continuous_train_and_eval"

# Init the HParams object from the named set, then override selected values.
hparams = create_hparams(HPARAMS)
hparams.batch_size = batch_size
hparams.learning_rate = ALPHA
# Train the model: build the run config and the experiment, then launch
# the continuous train-and-eval loop.
RUN_CONFIG = create_run_config(
      model_dir=TRAIN_DIR,  # checkpoints and summaries are written here
      model_name=MODEL,
      save_checkpoints_steps= save_checkpoints_steps,
      keep_checkpoint_max=5  # retain only the 5 most recent checkpoints
)

tensorflow_exp_fn = create_experiment(
        run_config=RUN_CONFIG,
        hparams=hparams,
        model_name=MODEL,
        problem_name=PROBLEM,
        data_dir=DATA_DIR,
        train_steps=train_steps,
        eval_steps=eval_steps,
        schedule=schedule  # must match the method invoked below
    )

# Alternates training and evaluation until train_steps is reached.
tensorflow_exp_fn.continuous_train_and_eval()
# tensorflow_exp_fn.train_and_evaluate()  # alternative schedule
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.