@pltrdy
Created July 3, 2017 08:52
Running PTB LM (10k vocabulary) benchmark with tensor2tensor
#!/bin/bash
set -e
source activate tensorflow
# See what problems, models, and hyperparameter sets are available.
# You can easily swap between them (and add new ones).
#t2t-trainer --registry_help
PROBLEM="lmptb_10k"
MODEL="attention_lm"
HPARAMS="attention_lm_base"
DATA_DIR="./data"
TMP_DIR="/tmp/t2t_datagen"
TRAIN_DIR="attn_lm"
VALID_SRC="$TMP_DIR/simple-examples/data/ptb.valid.txt"
mkdir -p "$DATA_DIR" "$TMP_DIR" "$TRAIN_DIR"
echo "Parameters:"
# Pass values as printf arguments, not inside the format string.
printf '* Problem:\t%s\n' "$PROBLEM"
printf '* Model:\t%s\n' "$MODEL"
printf '* HParams:\t%s\n' "$HPARAMS"
printf '* Data:\t%s\n' "$DATA_DIR"
printf '* Tmp:\t%s\n' "$TMP_DIR"
printf '* Train:\t%s\n' "$TRAIN_DIR"
printf '* Cuda device:\t%s\n' "${CUDA_VISIBLE_DEVICES:-unset}"
generate(){
  echo "Generating data..."
  t2t-datagen \
    --data_dir="$DATA_DIR" \
    --tmp_dir="$TMP_DIR" \
    --num_shards=100 \
    --problem="$PROBLEM"
}
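train(){
  echo "Training..."
  # Note: this runs t2t-trainer from a local tensor2tensor source tree,
  # while generate and decode use the installed CLIs.
  python3 tensor2tensor/bin/t2t-trainer \
    --model="$MODEL" \
    --problems="$PROBLEM" \
    --hparams_set="$HPARAMS" \
    --output_dir="$TRAIN_DIR" \
    --data_dir="$DATA_DIR"
}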
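decode(){
  echo "Decoding..."
  DECODE_FILE="$VALID_SRC"
  BEAM_SIZE=4
  ALPHA="0.6"  # length-penalty alpha for beam search
  # --train_steps=0 and --eval_steps=0 skip training and evaluation,
  # so the trainer only restores the checkpoint in TRAIN_DIR and decodes.
  t2t-trainer \
    --data_dir="$DATA_DIR" \
    --problems="$PROBLEM" \
    --model="$MODEL" \
    --hparams_set="$HPARAMS" \
    --output_dir="$TRAIN_DIR" \
    --train_steps=0 \
    --eval_steps=0 \
    --decode_from_file="$DECODE_FILE" \
    --decode_beam_size="$BEAM_SIZE" \
    --decode_alpha="$ALPHA"
}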
action="$1"
echo "$action"
case "$action" in
  gen)      generate ;;
  train)    train ;;
  gentrain) generate; train ;;
  decode)   decode ;;
  "")
    echo "No action given. Please pass one of: gen, train, gentrain, decode" >&2
    exit 1 ;;
  *)
    echo "Unknown action '$action'. Please pass one of: gen, train, gentrain, decode" >&2
    exit 1 ;;
esac
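The decode step runs beam search over every line of the full PTB validation file, which can take a while. For a quicker check of a trained model, a minimal sketch is to copy the first few lines into a smaller file and decode that instead. This assumes decode_from_file treats each line of its input file as one separate input; make_decode_sample is a hypothetical helper, not part of tensor2tensor.

```bash
# Hypothetical helper (not part of the original gist or of tensor2tensor):
# copy the first N lines of a text file into a smaller file, so a decode
# run over it finishes quickly. Assumes the decoder reads one input per line.
make_decode_sample() {
  local src="$1"        # source text file, e.g. "$VALID_SRC"
  local dst="$2"        # where to write the sample
  local n="${3:-10}"    # number of lines to keep (default: 10)
  if [ ! -f "$src" ]; then
    echo "error: $src not found (run the 'gen' action first?)" >&2
    return 1
  fi
  head -n "$n" "$src" > "$dst"
}

# Example (paths are illustrative):
#   make_decode_sample "$VALID_SRC" /tmp/ptb.valid.sample.txt 20
```

To use the sample, set DECODE_FILE inside decode to the sample path instead of "$VALID_SRC".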
@lapolonio

Is this decode function the intended way to predict based on an input?
