{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import cache_tasks_main\n",
"from typing import Mapping\n",
"import seqio\n",
"\n",
"import tensorflow as tf\n",
"\n",
"import functools\n",
"import t5.data\n",
"from t5.data.preprocessors import translate\n",
"from t5.evaluation import metrics"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# SeqIO"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### What is SeqIO?\n",
"\n",
"SeqIO is a library for processing sequential data to be fed into downstream sequence models. \n",
"It uses **tf.data.Dataset** to create scalable data pipelines but requires minimal use of TensorFlow. \n",
"In particular, with one line of code, the returned dataset can be transformed to a **numpy iterator** and hence it is fully compatible with other frameworks such as JAX or PyTorch."
]
},
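{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick illustration with a toy `tf.data.Dataset` standing in for a real SeqIO pipeline (`toy_ds` is illustrative; `tf` is imported in the first cell):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# One line makes the pipeline framework-agnostic:\n",
"toy_ds = tf.data.Dataset.from_tensor_slices({'inputs': [[7, 8, 5], [1, 2, 3]]})\n",
"for ex in toy_ds.as_numpy_iterator():\n",
"    print(ex)  # each `ex` is a dict of NumPy arrays, usable from JAX or PyTorch"
]
},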
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"At a high level, we use SeqIO with the following steps:\n",
" - Define a Task (and optionally a Mixture).\n",
" - Define (or use an existing) a FeatureConverter based on the model architecture.\n",
" - Use the top-level function seqio.get_dataset to obtain the tf.data.Dataset instance."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Example of a Task for translation. It combines:\n",
"# - Raw data source\n",
"# - One or more preprocessing steps\n",
"# - A vocabulary to tokenize/detokenize each preprocessed feature for the model\n",
"# - A postprocessor to convert detokenized model outputs into a format for evaluation\n",
"# - One or more metrics to evaluate with\n",
"\n",
"# Add to global registry\n",
"seqio.TaskRegistry.add(\n",
" \"wmt19_ende\", # Unique name\n",
" seqio.TfdsDataSource(tfds_name=\"wmt19_translate/de-en:1.0.0\"), # Load as tf.data.Dataset\n",
" preprocessors=[\n",
" functools.partial(\n",
" translate, \n",
" source_language='en', target_language='de'),\n",
" seqio.preprocessors.tokenize, \n",
" seqio.preprocessors.append_eos\n",
" ],\n",
" output_features={\n",
" 'inputs':\n",
" seqio.Feature(\n",
" seqio.SentencePieceVocabulary('/path/to/inputs/vocab'),\n",
" add_eos=False,\n",
" dtype=tf.int32),\n",
" 'targets':\n",
" seqio.Feature(\n",
" seqio.SentencePieceVocabulary('/path/to/targets/vocab'),\n",
" add_eos=True,\n",
" dtype=tf.int32),\n",
" },\n",
" metric_fns=[metrics.bleu])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load raw data in many formats as a tf.data.Dataset\n",
"\n",
"# Tensorflow Datasets\n",
"seqio.TfdsDataSource(...)\n",
"\n",
"# TextLineDataSource\n",
"seqio.TextLineDataSource(...)\n",
"\n",
"# TFExampleDataSource\n",
"seqio.TFExampleDataSource(...)\n",
"\n",
"# FunctionDataSource\n",
"seqio.FunctionDataSource(...)"
]
},
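{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, `seqio.FunctionDataSource` wraps any function that returns a `tf.data.Dataset`. A minimal sketch (`my_dataset_fn` and its in-memory data are illustrative, not part of SeqIO):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def my_dataset_fn(split: str, shuffle_files: bool, seed=None):\n",
"    # Toy in-memory data; a real source would read files here.\n",
"    del shuffle_files, seed  # nothing to shuffle at the file level\n",
"    data = {'train': ['That is good.', 'Das ist gut.']}\n",
"    return tf.data.Dataset.from_tensor_slices({'text': data[split]})\n",
"\n",
"source = seqio.FunctionDataSource(\n",
"    dataset_fn=my_dataset_fn,\n",
"    splits=['train'])"
]
},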
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Output field Map[str:seqio.Feature]\n",
"# Defines what the Task is expected to produce in its output example\n",
"\n",
"# Contains:\n",
"# - vocabulary (how to tokenize / detokenize)\n",
"# - add_eos\n",
"# - dtype\n",
"output_features={\n",
" 'inputs':\n",
" seqio.Feature(\n",
" seqio.SentencePieceVocabulary('/path/to/inputs/vocab'),\n",
" add_eos=False,\n",
" dtype=tf.int32),\n",
" 'targets':\n",
" seqio.Feature(\n",
" seqio.SentencePieceVocabulary('/path/to/targets/vocab'),\n",
" add_eos=True,\n",
" dtype=tf.int32)\n",
"}\n",
"\n",
"# Decoder only?"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Preprocessors: Transform tf.data.Dataset into a new tf.data.Dataset\n",
"preprocessors=[\n",
" functools.partial(\n",
" translate, source_language='en', target_language='de'),\n",
" seqio.preprocessors.tokenize, \n",
" seqio.preprocessors.append_eos\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@seqio.map_over_dataset\n",
"def translate(ex: Mapping[str, tf.Tensor],\n",
" source_language: str,\n",
" target_language: str) -> Mapping[str, tf.Tensor]:\n",
" \"\"\"Convert a translation dataset to a text2text pair.\n",
"\n",
" For example, say the dataset returns examples of this format:\n",
" {'de': 'Das ist gut.', 'en': 'That is good.'}\n",
" If source_language = 'de', target_language = 'en', then the outputs will have\n",
" the format:\n",
" {'inputs': 'translate German to English: Das ist gut.',\n",
" 'targets': 'That is good.'}\n",
"\n",
" Args:\n",
" ex: an example to process.\n",
" source_language: source language code (e.g. 'en') to translate from.\n",
" target_language: target language code (e.g. 'de') to translate to.\n",
"\n",
" Returns:\n",
" A preprocessed example with the format listed above.\n",
" \"\"\"\n",
" src_str = f'translate {source_language}'\n",
" tgt_str = f' to {target_language}: '\n",
" return {\n",
" 'inputs': tf.strings.join([src_str, tgt_str, ex[source_language]]),\n",
" 'targets': ex[target_language],\n",
" }"
]
},
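{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The same preprocessor can also be written without the `seqio.map_over_dataset` decorator by calling `dataset.map` explicitly:"
]
},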
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def translate(dataset: tf.data.Dataset,\n",
" source_language: str,\n",
" target_language: str) -> tf.data.Dataset:\n",
" def _translate(ex: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:\n",
" \"\"\"Convert a translation example to a text2text pair.\n",
"\n",
" For example, say the dataset returns examples of this format:\n",
" {'de': 'Das ist gut.', 'en': 'That is good.'}\n",
" If source_language = 'de', target_language = 'en', then the outputs will have\n",
" the format:\n",
" {'inputs': 'translate de to en: Das ist gut.',\n",
" 'targets': 'That is good.'}\n",
"\n",
" Args:\n",
" ex: an example to process.\n",
" source_language: source language code (e.g. 'en') to translate from.\n",
" target_language: target language code (e.g. 'de') to translate to.\n",
"\n",
" Returns:\n",
" A preprocessed example with the format listed above.\n",
" \"\"\"\n",
" src_str = f'translate {source_language}'\n",
" tgt_str = f' to {target_language}: '\n",
" return {\n",
" 'inputs': tf.strings.join([src_str, tgt_str, ex[source_language]]),\n",
" 'targets': ex[target_language],\n",
" }\n",
"\n",
" return dataset.map(_translate,\n",
" num_parallel_calls=tf.data.AUTOTUNE)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Post processor and Metrics\n",
"\n",
"Before passing predictions to the evaluation function, that data can be processed with a python function."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Prediction metrics are computed using the postprocessed targets and model outputs (predictions). The args must be named `targets` and `predictions`.\n",
"\n",
"Let's look at the metric function used for \"wmt19_ende\" task. A standard metric for the translation task is BLEU and we use sacrebleu implementation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def tqa_open_preprocessor(\n",
" dataset: tf.data.Dataset,\n",
" prefix: str = \"trivia_qa question: \"\n",
") -> tf.data.Dataset:\n",
"\n",
" @seqio.map_over_dataset\n",
" def tqa_map(ex: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:\n",
" \"\"\"Map TriviaQA example to text-to-text example.\"\"\"\n",
" return {\n",
" \"inputs\": prefix + ex[\"question\"],\n",
" \"targets\": ex[\"answer\"][\"value\"],\n",
" \"answers\": ex[\"answer\"][\"aliases\"],\n",
" }\n",
"\n",
" return tqa_map(dataset)"
]
},
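{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell above is the *preprocessor* that produces the `answers` feature. The matching *postprocessor* (as shown in the SeqIO README) decodes targets and model outputs before they reach the metric functions:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def tqa_open_postprocessor(output_or_target, example=None, is_target=False):\n",
"    \"\"\"Returns output as answer, or all answers if the full example is provided.\"\"\"\n",
"    if is_target:\n",
"        return [a.decode(\"utf-8\") for a in example[\"answers\"]]\n",
"    else:\n",
"        return output_or_target.decode(\"utf-8\")"
]
},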
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"postprocess_fn=tqa_open_preprocessor"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Metrics function\n",
"metric_fns=[metrics.bleu]"
]
},
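{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of what such a metric function looks like, assuming `sacrebleu` is installed (`bleu_sketch` is illustrative; the real `t5.evaluation.metrics.bleu` also handles multiple references and tokenization options):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sacrebleu\n",
"\n",
"def bleu_sketch(targets, predictions):\n",
"    \"\"\"Corpus BLEU via sacreBLEU; the args must be named targets and predictions.\"\"\"\n",
"    # corpus_bleu expects a list of hypotheses and a list of reference lists.\n",
"    bleu_score = sacrebleu.corpus_bleu(predictions, [targets])\n",
"    return {'bleu': bleu_score.score}"
]
},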
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"seqio.TaskRegistry.add(\n",
" \"wmt19_ende\", # Unique name\n",
" seqio.TfdsDataSource(tfds_name=\"wmt19_translate/de-en:1.0.0\"), # Load as tf.data.Dataset\n",
" preprocessors=[\n",
" functools.partial(\n",
" translate, \n",
" source_language='en', target_language='de'),\n",
" seqio.preprocessors.tokenize, \n",
" seqio.preprocessors.append_eos\n",
" ],\n",
" output_features={\n",
" 'inputs':\n",
" seqio.Feature(\n",
" seqio.SentencePieceVocabulary('/path/to/inputs/vocab'),\n",
" add_eos=False,\n",
" dtype=tf.int32),\n",
" 'targets':\n",
" seqio.Feature(\n",
" seqio.SentencePieceVocabulary('/path/to/targets/vocab'),\n",
" add_eos=True,\n",
" dtype=tf.int32),\n",
" },\n",
" metric_fns=[metrics.bleu])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Define Mixtures\n",
"seqio.MixtureRegistry.add(\n",
" \"mix2\",\n",
" [\"task1\", \"task2\"]\n",
")\n",
"\n",
"# Mixture in your Mixture\n",
"seqio.MixtureRegistry.add(\n",
" \"mix3\",\n",
" [\"mix1\", \"task1\", \"task3\"]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Getting the dataset\n",
"dataset = seqio.get_mixture_or_task(\"mix1\").get_dataset(\n",
" sequence_length={\"inputs\": 256, \"targets\": 128},\n",
" split=\"train\",\n",
" shuffle=True,\n",
" num_epochs=1,\n",
" shard_info=seqio.ShardInfo(index=0, num_shards=10), # load a deterministic subset of the dataset\n",
" use_cached=False, # load from a pre-cached task or process on-the-fly\n",
" seed=42 # deterministic shuffling and (stateless) stochastic ops\n",
")\n",
"\n",
"# Print the first 5 examples.\n",
"for _, ex in zip(range(5), dataset.as_numpy_iterator()):\n",
" print(ex)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"#### (optional) Offline Caching\n",
"\n",
"For improved performance at load time and avoid redundant computations for commonly used tasks, you can pre-cache your Task with all or part of the preprocessing done in advance of training.\n",
"\n",
"Add:\n",
"\n",
"```python\n",
"seqio.CacheDatasetPlaceholder(required=False)\n",
"```\n",
"\n",
"as one of the steps in your preprocessing pipeline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Run as an Apache Beam job on Cloud DataFlow\n",
"cache_tasks_main"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, you are ready to load the cached version of your Task (or Mixture) containing it. \n",
"You will need to add the path to the directory you passed to `--output_cache_dir` via `seqio.add_global_cache_dirs([\"/my/cache/dir\"])`. \n",
"\n",
"Now when you call `task_or_mixture.get_dataset(..., use_cached=True)`, the data will be loaded from the cache directory instead of the raw data source."
]
},
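{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Putting those two calls together (the cache directory below is a placeholder for your `--output_cache_dir` value, and the Task must include `seqio.CacheDatasetPlaceholder` in its preprocessors):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"seqio.add_global_cache_dirs(['/my/cache/dir'])  # placeholder path\n",
"\n",
"# The data is now read from the cache instead of the raw data source.\n",
"dataset = seqio.get_mixture_or_task('wmt19_ende').get_dataset(\n",
"    sequence_length={'inputs': 256, 'targets': 128},\n",
"    split='train',\n",
"    use_cached=True)"
]
},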
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## End to end sample"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"DEFAULT_OUTPUT_FEATURES = {\n",
" \"inputs\": seqio.Feature(\n",
" vocabulary=t5.data.get_default_vocabulary(), \n",
" add_eos=True,\n",
" required=False),\n",
" \"targets\": seqio.Feature(\n",
" vocabulary=t5.data.get_default_vocabulary(), \n",
" add_eos=True)\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@seqio.map_over_dataset\n",
"def summarize(x, article_key, summary_key):\n",
" \"\"\"Convert a summarization dataset to a text2text pair.\n",
"\n",
" For example, say the dataset returns examples of this format:\n",
" {'article': <article>, 'highlights': <summary>}\n",
" If article_key = 'article', summary_key = 'highlights', then the outputs will\n",
" have the format:\n",
" {'inputs': 'summarize': <article>, 'targets': <summary>}\n",
"\n",
" Args:\n",
" x: an example to process.\n",
" article_key: the feature key for the article to summarize.\n",
" summary_key: the feature key for the target summary.\n",
" Returns:\n",
" A preprocessed example with the format listed above.\n",
" \"\"\"\n",
" strs_to_join = ['summarize:', x[article_key]]\n",
" return {\n",
" 'inputs': tf.strings.join(strs_to_join, separator=' '),\n",
" 'targets': x[summary_key],\n",
" }"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"task = seqio.Task(\n",
" name='custom_xsum',\n",
" source=seqio.TFExampleDataSource(\n",
" split_to_filepattern={\n",
" 'validation': 'gs://rl-llm-export-ft/xsum-dataset/validation.tfrecord'\n",
" },\n",
" feature_description={\n",
" 'document': tf.io.FixedLenFeature(shape=(), dtype=tf.string),\n",
" 'summary': tf.io.FixedLenFeature(shape=(), dtype=tf.string)\n",
" }\n",
" ),\n",
" preprocessors=[\n",
" functools.partial(\n",
" summarize,\n",
" article_key=\"document\",\n",
" summary_key=\"summary\"),\n",
" seqio.preprocessors.tokenize,\n",
" seqio.CacheDatasetPlaceholder(),\n",
" seqio.preprocessors.append_eos_after_trim,\n",
" ],\n",
" metric_fns=[metrics.rouge],\n",
" output_features=DEFAULT_OUTPUT_FEATURES\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"task_ds = task.get_dataset(sequence_length={'document':512, 'summary':128}, split='validation').as_numpy_iterator()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"next(iter(task_ds))"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# SeqIO - Feature Converters"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Tasks (task features): Datasets => Provide model-specific features (e.g.: generic inputs and outputs)\n",
"\n",
"Feature Converters (model features): Transform the model-agnostic features => model-specific features"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Raw data\n",
"\"That is good\\tDas ist gut.\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# After preprocessing from SeqIO\n",
"{\"inputs\": \"translate English to German: That is good.\",\n",
" \"targets\": \"Das ist gut.\"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sample output\n",
"dataset = [{\"inputs\": [7, 8, 5], \"targets\": [8, 4, 9, 3, 1]}]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# FeatureConverter: Converts to model-specific features (padding, packing)\n",
"# Transformer: Encoder - Decoder\n",
"converted_dataset = [{\n",
" \"encoder_input_tokens\": [7, 8, 5, 1, 8, 4, 9, 3, 1, 0],\n",
" \"encoder_segment_ids\": [1, 1, 1, 1, 2, 2, 2, 2, 2, 0],\n",
" \"encoder_positions\": [0, 1, 2, 3, 0, 1, 2, 3, 4, 0],\n",
" \"decoder_target_tokens\": [3, 9, 1, 4, 1, 0, 0],\n",
" \"decoder_input_tokens\": [0, 3, 9, 0, 4, 0, 0],\n",
" \"decoder_loss_weights\": [1, 1, 1, 1, 1, 0, 0],\n",
" \"decoder_positions\": [0, 1, 2, 0, 1, 0, 0],\n",
" \"decoder_segment_ids\": [1, 1, 1, 2, 2, 0, 0],\n",
"}]\n",
"\n",
"# encoder-decoder, decoder-only and encoder-only"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset: tf.data.Dataset = seqio.get_dataset(\n",
" mixture_or_task_name=\"wmt_t2t_ende_v003\",\n",
" task_feature_lengths={\"inputs\": 32, \"targets\": 32},\n",
" dataset_split=\"train\",\n",
" shuffle=True,\n",
" feature_converter=seqio.EncDecFeatureConverter(pack=True)\n",
")"
]
},
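{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Inspecting one converted example (assuming the cell above has run) shows the model features described earlier:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for ex in dataset.take(1).as_numpy_iterator():\n",
"    print(list(ex.keys()))"
]
},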
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tasks defined for this solution"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from tasks import custom_tasks"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}