# Source: GitHub gist fancyerii/9c9ea297a71cb8b36519705e67638a3e
# Author: @fancyerii. Created October 31, 2020 01:30.
# (GitHub page navigation text -- star/fork/save prompts -- removed; it is not part of the script.)
import os
import sys
import numpy as np
import collections
import matplotlib.pyplot as plt# Colab-only TensorFlow version selector
import tensorflow as tf
from tensor2tensor import models
from tensor2tensor import problems
from tensor2tensor.layers import common_layers
from tensor2tensor.utils import trainer_lib
from tensor2tensor.utils import t2t_model
from tensor2tensor.utils import registry
from tensor2tensor.utils import metrics
from tensor2tensor.utils.trainer_lib import (create_hparams,
create_run_config,
create_experiment)
# NOTE(review): this heading says "Enable TF Eager execution", but no call that
# enables eager execution follows it in this script -- presumably left over
# from a notebook cell; confirm whether eager mode was intended.
# Other setup
# Short alias for the Estimator mode keys.
Modes = tf.estimator.ModeKeys
import pathlib

# Working directories for this run, all located under ~/t2tcn2.
DATA_DIR = os.path.expanduser("~/t2tcn2/data")  # holds the data
TMP_DIR = os.path.expanduser("~/t2tcn2/tmp")  # temporary files
TRAIN_DIR = os.path.expanduser("~/t2tcn2/train")  # holds the model
EXPORT_DIR = os.path.expanduser("~/t2tcn2/export")  # holds the model exported for production
TRANSLATIONS_DIR = os.path.expanduser("~/t2tcn2/translation")  # holds all translated sequences
EVENT_DIR = os.path.expanduser("~/t2tcn2/event")  # used when testing the BLEU score
USR_DIR = os.path.expanduser("~/t2tcn2/user")  # holds our own data that we want to add

# Create every directory up front, including missing parents; directories that
# already exist are left untouched.
for _directory in (
    DATA_DIR,
    TMP_DIR,
    TRAIN_DIR,
    EXPORT_DIR,
    TRANSLATIONS_DIR,
    EVENT_DIR,
    USR_DIR,
):
    pathlib.Path(_directory).mkdir(parents=True, exist_ok=True)
del _directory  # do not leave the loop variable behind as a module-level name
# Names that select the model, its hyperparameter set and the problem.
# Discovery helpers (not run here):
#   problems.available()    -- show all problems
#   registry.list_models()  -- show all registered models

MODEL = "transformer"

# Default hyperparameter set for the model: start with "transformer_base", or
# "transformer_base_single_gpu" if training on a single GPU.
HPARAMS = "transformer_base_single_gpu"

# An English-Chinese dataset with an 8192-entry vocabulary.
PROBLEM = "translate_enzh_wmt8k"

# Look the problem up by name, then generate its data; generate_data is given
# the data directory and the temporary directory, in that order.
t2t_problem = problems.problem(PROBLEM)
t2t_problem.generate_data(DATA_DIR, TMP_DIR)
# --- Training loop settings ---------------------------------------------
schedule = "continuous_train_and_eval"
train_steps = 20000  # total number of training steps, over all epochs
eval_steps = 100  # number of steps performed for each evaluation
save_checkpoints_steps = 1000  # save a checkpoint every 1000 steps

# --- Hyperparameter overrides -------------------------------------------
# NOTE(review): for T2T text problems batch_size is presumably counted in
# tokens rather than sentences -- confirm against the tensor2tensor docs.
batch_size = 1000
# NOTE(review): whether `learning_rate` is the value the optimizer actually
# uses presumably depends on the hparams set's learning-rate schedule --
# confirm for HPARAMS.
ALPHA = 0.1  # learning rate

# Initialise the hparams object from the named set, then apply the overrides.
hparams = create_hparams(HPARAMS)
hparams.learning_rate = ALPHA
hparams.batch_size = batch_size
# Train the model.

# Run configuration: TRAIN_DIR is the model directory, a checkpoint is saved
# every `save_checkpoints_steps` steps, and keep_checkpoint_max is set to 5.
RUN_CONFIG = create_run_config(
    model_dir=TRAIN_DIR,
    model_name=MODEL,
    save_checkpoints_steps=save_checkpoints_steps,
    keep_checkpoint_max=5,
)

# NOTE: despite the `_fn` suffix this name is never called itself; it is used
# as an object whose schedule method is invoked below. The name is kept so any
# existing references to it keep working.
tensorflow_exp_fn = create_experiment(
    run_config=RUN_CONFIG,
    hparams=hparams,
    model_name=MODEL,
    problem_name=PROBLEM,
    data_dir=DATA_DIR,
    train_steps=train_steps,
    eval_steps=eval_steps,
    schedule=schedule,
)

# Run the method named by `schedule`, so the schedule handed to
# create_experiment and the one actually executed cannot drift apart.
# With schedule == "continuous_train_and_eval" this is exactly
# tensorflow_exp_fn.continuous_train_and_eval(); to use the alternative that
# was previously kept here as commented-out code, set
# schedule = "train_and_evaluate".
getattr(tensorflow_exp_fn, schedule)()
# (GitHub page footer, not part of the script: sign up or sign in on GitHub to comment on this gist.)