Skip to content

Instantly share code, notes, and snippets.

@danyaljj
Created April 9, 2021 23:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save danyaljj/9eda0b32014775062b62b25ceb57ec31 to your computer and use it in GitHub Desktop.
Save danyaljj/9eda0b32014775062b62b25ceb57ec31 to your computer and use it in GitHub Desktop.
# Run configuration. Every value below can be overridden from the environment
# instead of editing this file, e.g.
#   PROJECT=my-gcp-project ZONE=my-tpu-zone bash <this script>
# PROJECT and ZONE default to the placeholder '???' and need real values.
export TPU_NAME="${TPU_NAME:-sihao02}"
export PROJECT="${PROJECT:-???}"
export ZONE="${ZONE:-???}"
# GCS prefix under which fine-tuned models and the TFDS data dir are stored.
export BUCKET="${BUCKET:-gs://sihao-source/models}"
# Step number of the pretrained checkpoint to start from, and how many
# additional steps to fine-tune on top of it.
PRETRAINED_STEPS="${PRETRAINED_STEPS:-1000000}"
FINETUNE_STEPS="${FINETUNE_STEPS:-50000}"
# Model sizes and task (mixture) names to iterate over.
declare -a sizes=("large")
declare -a tasks=("twitter")
# For every (task, size) pair: fine-tune from the public T5 checkpoint, then
# evaluate all saved checkpoints on the 'dev' and 'test' splits.
#
# Reads: tasks, sizes, TPU_NAME, PROJECT, ZONE, BUCKET, PRETRAINED_STEPS,
#        FINETUNE_STEPS.
# Exits: 1 immediately if a required setting is unset or still '???';
#        1 at the end if any fine-tuning or evaluation run failed.

# Refuse to launch TPU jobs with missing or placeholder settings.
for required_var in TPU_NAME PROJECT ZONE BUCKET; do
  if [[ -z "${!required_var:-}" || "${!required_var}" == '???' ]]; then
    printf 'error: %s is unset or still the placeholder "???"\n' \
      "$required_var" >&2
    exit 1
  fi
done

# Mesh/batch settings used identically by fine-tuning and evaluation.
tpu_topology='v3-8'
model_parallelism=8
batch_size="('tokens_per_batch', 4096)"
failed_runs=0

for TASK in "${tasks[@]}"; do
  for SIZE in "${sizes[@]}"; do
    PRETRAINED_DIR="gs://t5-data/pretrained_models/${SIZE}"
    MODEL_DIR="${BUCKET}/${TASK}/${SIZE}"

    # Non-gin flags passed to every mesh_transformer_main run for this pair.
    common_flags=(
      --module_import="train"
      --tpu="${TPU_NAME}"
      --gcp_project="${PROJECT}"
      --tpu_zone="${ZONE}"
      --model_dir="${MODEL_DIR}"
      --t5_tfds_data_dir="${BUCKET}/t5-tfds"
    )
    # Gin bindings shared by training and evaluation. None of them binds the
    # same parameter as a run-specific binding below, so their position
    # relative to those bindings does not matter.
    common_params=(
      --gin_param="utils.tpu_mesh_shape.tpu_topology = '${tpu_topology}'"
      --gin_param="MIXTURE_NAME = '${TASK}'"
      --gin_param="utils.tpu_mesh_shape.model_parallelism = ${model_parallelism}"
      --gin_param="utils.run.batch_size=${batch_size}"
    )

    # Run fine-tuning. train_steps is the total step count, so the pretrained
    # steps are added to FINETUNE_STEPS.
    if ! python -m t5.models.mesh_transformer_main \
      "${common_flags[@]}" \
      --gin_file="dataset.gin" \
      --gin_file="${PRETRAINED_DIR}/operative_config.gin" \
      "${common_params[@]}" \
      --gin_param="utils.run.save_checkpoints_steps=5000" \
      --gin_param="utils.run.train_steps=$((PRETRAINED_STEPS + FINETUNE_STEPS))" \
      --gin_param="utils.run.init_checkpoint='${PRETRAINED_DIR}/model.ckpt-${PRETRAINED_STEPS}'"; then
      # Evaluating would only read a missing/stale model dir; skip it.
      printf 'error: fine-tuning failed for task=%s size=%s; skipping eval\n' \
        "$TASK" "$SIZE" >&2
      failed_runs=$((failed_runs + 1))
      continue
    fi

    # Run eval on every saved checkpoint ('all') for each split. Reads
    # ${MODEL_DIR}/operative_config.gin, presumably written there by the
    # fine-tuning run above.
    for eval_split in dev test; do
      if ! python -m t5.models.mesh_transformer_main \
        "${common_flags[@]}" \
        --gin_file="dataset.gin" \
        --gin_file="${MODEL_DIR}/operative_config.gin" \
        --gin_file="eval.gin" \
        "${common_params[@]}" \
        --gin_param="utils.run.dataset_split = '${eval_split}'" \
        --gin_param="utils.run.eval_checkpoint_step='all'"; then
        printf 'error: %s eval failed for task=%s size=%s\n' \
          "$eval_split" "$TASK" "$SIZE" >&2
        failed_runs=$((failed_runs + 1))
      fi
    done
  done
done

# Surface failures through the exit status instead of only the last command's.
if (( failed_runs > 0 )); then
  printf 'error: %d run(s) failed\n' "$failed_runs" >&2
  exit 1
fi
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment