Created
August 17, 2022 19:03
-
-
Save leiterenato/fcea49ce36be51ba789f718e2db7c47f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import functools\n", | |
"import os\n", | |
"import seqio\n", | |
"\n", | |
"import t5.data\n", | |
"from t5.evaluation import metrics\n", | |
"from t5.data import preprocessors\n", | |
"from t5.data import postprocessors\n", | |
"\n", | |
"import tensorflow as tf" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Setup variables" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"GIN_FILE = '/home/renatoleite/workspace/export-t5x-tf/examples/export_t511_large_cnn_dailymail.gin'\n", | |
"MODEL_NAME = 't511-large-cnn-dailymail'\n", | |
"BATCH_SIZE = 1\n", | |
"XLA = '-xla'\n", | |
"IS_XLA = True if XLA else False\n", | |
"MODEL_OUTPUT_DIR = f'/home/renatoleite/model_export/{MODEL_NAME}{XLA}/{BATCH_SIZE}'" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### [INTERNAL] Export model from command line" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"vscode": { | |
"languageId": "shellscript" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"# Internal tool\n", | |
"! python /home/renatoleite/workspace/export-t5x-tf/export.py \\\n", | |
"--gin_file={GIN_FILE} \\\n", | |
"--gin.BATCH_SIZE={BATCH_SIZE} \\\n", | |
"--gin.MODEL_NAME=\\\"{MODEL_NAME}\\\" \\\n", | |
"--gin.MODEL_OUTPUT_DIR=\\\"{MODEL_OUTPUT_DIR}\\\" \\\n", | |
"--gin.export_lib.ExportableModule.jit_compile={IS_XLA}" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### XSUM: Retrieve dataset\n", | |
"Execute the following 2 cells ONLY if you consider using XSUM dataset." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"DEFAULT_OUTPUT_FEATURES = {\n", | |
" \"inputs\": seqio.Feature(\n", | |
" vocabulary=t5.data.get_default_vocabulary(), add_eos=True,\n", | |
" required=False),\n", | |
" \"targets\": seqio.Feature(\n", | |
" vocabulary=t5.data.get_default_vocabulary(), add_eos=True)\n", | |
"}\n", | |
"\n", | |
"paths = {\n", | |
" 'train': 'gs://rl-language/dataset/xsum/xsum_train.tfrecord',\n", | |
" 'test': 'gs://rl-language/dataset/xsum/xsum_test.tfrecord',\n", | |
" 'validation': 'gs://rl-language/dataset/xsum/xsum_validation.tfrecord'\n", | |
"}\n", | |
"\n", | |
"# Task definition\n", | |
"task = seqio.Task(\n", | |
" 'xsum',\n", | |
" source=seqio.TFExampleDataSource(\n", | |
" split_to_filepattern=paths, \n", | |
" feature_description={\n", | |
" 'document': tf.io.FixedLenFeature(shape=(), dtype=tf.string),\n", | |
" 'summary': tf.io.FixedLenFeature(shape=(), dtype=tf.string)\n", | |
" }\n", | |
" ),\n", | |
" preprocessors=[\n", | |
" functools.partial(\n", | |
" preprocessors.summarize,\n", | |
" article_key=\"document\",\n", | |
" summary_key=\"summary\"),\n", | |
" seqio.preprocessors.tokenize,\n", | |
" seqio.CacheDatasetPlaceholder(),\n", | |
" seqio.preprocessors.append_eos_after_trim,\n", | |
" ],\n", | |
" metric_fns=[metrics.rouge],\n", | |
" output_features=DEFAULT_OUTPUT_FEATURES)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Create a batch for inference\n", | |
"task_ds = task.get_dataset(sequence_length={'document':1024, 'summary':64}, split='train').as_numpy_iterator()\n", | |
"\n", | |
"example = []\n", | |
"for i in range(BATCH_SIZE):\n", | |
" example.append(task_ds.next())\n", | |
"\n", | |
"test = tf.constant([e['inputs_pretokenized'] for e in example], dtype=tf.string)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### CNN Daily: Retrive dataset\n", | |
"Execute the following cells ONLY if you consider using CNN Dailymain dataset. \n", | |
"Before executing the following cells, in case you are considering using CNN Daily dataset, copy the files from `gs://jk-t5x-staging/datasets/cnn_dailymail/3.4.0/*` to `~/tensorflow_datasets/cnn_dailymail/3.4.0`. You must create this directory first.\n", | |
"\n", | |
"`gsutil -m cp -r gs://jk-t5x-staging/datasets/cnn_dailymail/3.4.0/* ~/tensorflow_datasets/cnn_dailymail/3.4.0`" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"DEFAULT_OUTPUT_FEATURES = {\n", | |
" \"inputs\": seqio.Feature(\n", | |
" vocabulary=t5.data.get_default_vocabulary(), add_eos=True,\n", | |
" required=False),\n", | |
" \"targets\": seqio.Feature(\n", | |
" vocabulary=t5.data.get_default_vocabulary(), add_eos=True)\n", | |
"}\n", | |
"\n", | |
"# Task definition\n", | |
"task = seqio.Task(\n", | |
" \"cnn_dailymail_custom\",\n", | |
" source=seqio.TfdsDataSource(tfds_name=\"cnn_dailymail:3.4.0\"),\n", | |
" preprocessors=[\n", | |
" functools.partial(\n", | |
" preprocessors.summarize,\n", | |
" article_key=\"article\",\n", | |
" summary_key=\"highlights\"),\n", | |
" seqio.preprocessors.tokenize,\n", | |
" seqio.CacheDatasetPlaceholder(),\n", | |
" seqio.preprocessors.append_eos_after_trim,\n", | |
" ],\n", | |
" metric_fns=[metrics.rouge],\n", | |
" output_features=DEFAULT_OUTPUT_FEATURES)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"`get_dataset` from SeqIO returns a `tf.data.Dataset` with preprocessed data. \n", | |
"After calling this function, treat the return as a regular `tf.data.Dataset`. " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Create a batch for inference\n", | |
"task_ds = task.get_dataset(sequence_length={'article':512, 'highlights':64}, split='train').as_numpy_iterator()\n", | |
"\n", | |
"example = []\n", | |
"for i in range(BATCH_SIZE):\n", | |
" example.append(task_ds.next())\n", | |
"\n", | |
"test = tf.constant([e['inputs_pretokenized'] for e in example], dtype=tf.string)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Load model and run inference" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model = tf.saved_model.load(MODEL_OUTPUT_DIR)\n", | |
"infer = model.signatures[\"serving_default\"]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"results = infer(text_batch=test)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"for e, r in zip(example, results['output_0']):\n", | |
" print('Original: ', e['targets_pretokenized'])\n", | |
" print('Infered: ', r[0].numpy())\n", | |
" print()" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3.8.13 ('t5x-base')", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.13" | |
}, | |
"orig_nbformat": 4, | |
"vscode": { | |
"interpreter": { | |
"hash": "cb70e41daff3c8e885354b3a6acd41d74f1030e2e06d05a657a05712cd24b2a0" | |
} | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment