Skip to content

Instantly share code, notes, and snippets.

@leiterenato
Created August 17, 2022 19:03
Show Gist options
  • Save leiterenato/fcea49ce36be51ba789f718e2db7c47f to your computer and use it in GitHub Desktop.
Save leiterenato/fcea49ce36be51ba789f718e2db7c47f to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import functools\n",
"import os\n",
"import seqio\n",
"\n",
"import t5.data\n",
"from t5.evaluation import metrics\n",
"from t5.data import preprocessors\n",
"from t5.data import postprocessors\n",
"\n",
"import tensorflow as tf"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Setup variables"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Export configuration.\n",
"# NOTE(review): absolute local paths — consider making these configurable.\n",
"GIN_FILE = '/home/renatoleite/workspace/export-t5x-tf/examples/export_t511_large_cnn_dailymail.gin'\n",
"MODEL_NAME = 't511-large-cnn-dailymail'\n",
"BATCH_SIZE = 1\n",
"# Suffix appended to the export directory name; set to '' to disable XLA.\n",
"XLA = '-xla'\n",
"IS_XLA = bool(XLA)  # non-empty suffix -> compile with XLA\n",
"MODEL_OUTPUT_DIR = f'/home/renatoleite/model_export/{MODEL_NAME}{XLA}/{BATCH_SIZE}'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [INTERNAL] Export model from command line"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "shellscript"
}
},
"outputs": [],
"source": [
"# Internal tool\n",
"# IPython '!' shell escape: {VAR} interpolates the Python variables defined\n",
"# in the config cell above into the command line. The escaped \\\" keep the\n",
"# gin string bindings quoted when they reach the shell.\n",
"! python /home/renatoleite/workspace/export-t5x-tf/export.py \\\n",
"--gin_file={GIN_FILE} \\\n",
"--gin.BATCH_SIZE={BATCH_SIZE} \\\n",
"--gin.MODEL_NAME=\\\"{MODEL_NAME}\\\" \\\n",
"--gin.MODEL_OUTPUT_DIR=\\\"{MODEL_OUTPUT_DIR}\\\" \\\n",
"--gin.export_lib.ExportableModule.jit_compile={IS_XLA}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### XSUM: Retrieve dataset\n",
"Execute the following 2 cells ONLY if you consider using XSUM dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Task output features: 'inputs' (document) and 'targets' (summary), both\n",
"# tokenized with the default T5 SentencePiece vocabulary, EOS appended.\n",
"DEFAULT_OUTPUT_FEATURES = {\n",
" \"inputs\": seqio.Feature(\n",
" vocabulary=t5.data.get_default_vocabulary(), add_eos=True,\n",
" required=False),\n",
" \"targets\": seqio.Feature(\n",
" vocabulary=t5.data.get_default_vocabulary(), add_eos=True)\n",
"}\n",
"\n",
"# TFRecord file patterns for the XSUM dataset, one per split.\n",
"paths = {\n",
" 'train': 'gs://rl-language/dataset/xsum/xsum_train.tfrecord',\n",
" 'test': 'gs://rl-language/dataset/xsum/xsum_test.tfrecord',\n",
" 'validation': 'gs://rl-language/dataset/xsum/xsum_validation.tfrecord'\n",
"}\n",
"\n",
"# Task definition\n",
"# Pipeline: parse TFExamples -> summarize maps document/summary to\n",
"# inputs/targets -> tokenize -> optional offline cache -> append EOS\n",
"# after trimming to the requested sequence lengths.\n",
"task = seqio.Task(\n",
" 'xsum',\n",
" source=seqio.TFExampleDataSource(\n",
" split_to_filepattern=paths, \n",
" feature_description={\n",
" 'document': tf.io.FixedLenFeature(shape=(), dtype=tf.string),\n",
" 'summary': tf.io.FixedLenFeature(shape=(), dtype=tf.string)\n",
" }\n",
" ),\n",
" preprocessors=[\n",
" functools.partial(\n",
" preprocessors.summarize,\n",
" article_key=\"document\",\n",
" summary_key=\"summary\"),\n",
" seqio.preprocessors.tokenize,\n",
" seqio.CacheDatasetPlaceholder(),\n",
" seqio.preprocessors.append_eos_after_trim,\n",
" ],\n",
" metric_fns=[metrics.rouge],\n",
" output_features=DEFAULT_OUTPUT_FEATURES)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create a batch for inference.\n",
"# sequence_length keys must name the task's output features\n",
"# ('inputs'/'targets'), not the raw TFExample keys ('document'/'summary') —\n",
"# the summarize preprocessor has already renamed them by trimming time.\n",
"task_ds = task.get_dataset(sequence_length={'inputs': 1024, 'targets': 64}, split='train').as_numpy_iterator()\n",
"\n",
"# Use the builtin next() rather than the iterator's .next() method.\n",
"example = [next(task_ds) for _ in range(BATCH_SIZE)]\n",
"\n",
"test = tf.constant([e['inputs_pretokenized'] for e in example], dtype=tf.string)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### CNN Daily: Retrive dataset\n",
"Execute the following cells ONLY if you consider using CNN Dailymain dataset. \n",
"Before executing the following cells, in case you are considering using CNN Daily dataset, copy the files from `gs://jk-t5x-staging/datasets/cnn_dailymail/3.4.0/*` to `~/tensorflow_datasets/cnn_dailymail/3.4.0`. You must create this directory first.\n",
"\n",
"`gsutil -m cp -r gs://jk-t5x-staging/datasets/cnn_dailymail/3.4.0/* ~/tensorflow_datasets/cnn_dailymail/3.4.0`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Task output features: 'inputs' (article) and 'targets' (highlights), both\n",
"# tokenized with the default T5 SentencePiece vocabulary, EOS appended.\n",
"DEFAULT_OUTPUT_FEATURES = {\n",
" \"inputs\": seqio.Feature(\n",
" vocabulary=t5.data.get_default_vocabulary(), add_eos=True,\n",
" required=False),\n",
" \"targets\": seqio.Feature(\n",
" vocabulary=t5.data.get_default_vocabulary(), add_eos=True)\n",
"}\n",
"\n",
"# Task definition\n",
"# Same pipeline as the XSUM task, but sourced from the local TFDS copy of\n",
"# cnn_dailymail:3.4.0 (see the markdown cell above for the gsutil copy step).\n",
"task = seqio.Task(\n",
" \"cnn_dailymail_custom\",\n",
" source=seqio.TfdsDataSource(tfds_name=\"cnn_dailymail:3.4.0\"),\n",
" preprocessors=[\n",
" functools.partial(\n",
" preprocessors.summarize,\n",
" article_key=\"article\",\n",
" summary_key=\"highlights\"),\n",
" seqio.preprocessors.tokenize,\n",
" seqio.CacheDatasetPlaceholder(),\n",
" seqio.preprocessors.append_eos_after_trim,\n",
" ],\n",
" metric_fns=[metrics.rouge],\n",
" output_features=DEFAULT_OUTPUT_FEATURES)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`get_dataset` from SeqIO returns a `tf.data.Dataset` with preprocessed data. \n",
"After calling this function, treat the return as a regular `tf.data.Dataset`. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create a batch for inference.\n",
"# sequence_length keys must name the task's output features\n",
"# ('inputs'/'targets'), not the raw TFDS keys ('article'/'highlights') —\n",
"# the summarize preprocessor has already renamed them by trimming time.\n",
"task_ds = task.get_dataset(sequence_length={'inputs': 512, 'targets': 64}, split='train').as_numpy_iterator()\n",
"\n",
"# Use the builtin next() rather than the iterator's .next() method.\n",
"example = [next(task_ds) for _ in range(BATCH_SIZE)]\n",
"\n",
"test = tf.constant([e['inputs_pretokenized'] for e in example], dtype=tf.string)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load model and run inference"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load the exported SavedModel and look up its default serving signature.\n",
"model = tf.saved_model.load(MODEL_OUTPUT_DIR)\n",
"infer = model.signatures[\"serving_default\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Run the serving signature on the batch of raw input strings.\n",
"results = infer(text_batch=test)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Compare the reference summary of each example with the model prediction.\n",
"# NOTE(review): assumes each element of results['output_0'] holds the\n",
"# decoded candidates with the top one at index 0 — confirm against the\n",
"# export signature before relying on r[0].\n",
"for e, r in zip(example, results['output_0']):\n",
" print('Original: ', e['targets_pretokenized'])\n",
" print('Inferred: ', r[0].numpy())\n",
" print()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.13 ('t5x-base')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "cb70e41daff3c8e885354b3a6acd41d74f1030e2e06d05a657a05712cd24b2a0"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment