{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import cache_tasks_main\n",
"from typing import Mapping\n",
"import seqio\n",
"\n",
"import tensorflow as tf\n",
"\n",
"import functools\n",
"import t5.data\n",
"from t5.data.preprocessors import translate\n",
"from t5.evaluation import metrics"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# SeqIO"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### What is SeqIO?\n",
"\n",
"SeqIO is a library for processing sequential data to be fed into downstream sequence models. \n",
"It uses **tf.data.Dataset** to create scalable data pipelines but requires minimal use of TensorFlow. \n",
"In particular, with one line of code, the returned dataset can be transformed to a **numpy iterator** and hence it is fully compatible with other frameworks such as JAX or PyTorch."
]
},
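{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch of the \"one line of code\" claim above. It assumes a\n",
"# registered task such as \"wmt19_ende\" (defined later in this notebook);\n",
"# as_numpy_iterator() yields plain dicts of numpy arrays, usable from\n",
"# JAX or PyTorch.\n",
"ds = seqio.get_mixture_or_task(\"wmt19_ende\").get_dataset(\n",
"    sequence_length={\"inputs\": 32, \"targets\": 32}, split=\"train\")\n",
"for ex in ds.as_numpy_iterator():\n",
"    print(ex)\n",
"    break"
]
},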
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"At a high level, we use SeqIO with the following steps:\n",
" - Define a Task (and optionally a Mixture).\n",
" - Define (or use an existing) a FeatureConverter based on the model architecture.\n",
" - Use the top-level function seqio.get_dataset to obtain the tf.data.Dataset instance."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Example of a Task for translation. It combines:\n",
"# - Raw data source\n",
"# - One or more preprocessing steps\n",
"# - A vocabulary to tokenize/detokenize each preprocessed feature for the model\n",
"# - A postprocessor to convert detokenized model outputs into a format for evaluation\n",
"# - One or more metrics to evaluate with\n",
"\n",
"# Add to global registry\n",
"seqio.TaskRegistry.add(\n",
" \"wmt19_ende\", # Unique name\n",
" seqio.TfdsDataSource(tfds_name=\"wmt19_translate/de-en:1.0.0\"), # Load as tf.data.Dataset\n",
" preprocessors=[\n",
" functools.partial(\n",
" translate, \n",
" source_language='en', target_language='de'),\n",
" seqio.preprocessors.tokenize, \n",
" seqio.preprocessors.append_eos\n",
" ],\n",
" output_features={\n",
" 'inputs':\n",
" seqio.Feature(\n",
" seqio.SentencePieceVocabulary('/path/to/inputs/vocab'),\n",
" add_eos=False,\n",
" dtype=tf.int32),\n",
" 'targets':\n",
" seqio.Feature(\n",
" seqio.SentencePieceVocabulary('/path/to/targets/vocab'),\n",
" add_eos=True,\n",
" dtype=tf.int32),\n",
" },\n",
" metric_fns=[metrics.bleu])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load raw data in many formats as a tf.data.Dataset\n",
"\n",
"# Tensorflow Datasets\n",
"seqio.TfdsDataSource(...)\n",
"\n",
"# TextLineDataSource\n",
"seqio.TextLineDataSource(...)\n",
"\n",
"# TFExampleDataSource\n",
"seqio.TFExampleDataSource(...)\n",
"\n",
"# FunctionDataSource\n",
"seqio.FunctionDataSource(...)"
]
},
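{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A concrete sketch of two of the sources above; the file pattern is a\n",
"# placeholder, not a real path.\n",
"seqio.TextLineDataSource(\n",
"    split_to_filepattern={\"train\": \"/path/to/train.txt*\"},\n",
"    skip_header_lines=0)\n",
"\n",
"# FunctionDataSource wraps any function that returns a tf.data.Dataset.\n",
"def my_dataset_fn(split, shuffle_files, seed=None):\n",
"    del shuffle_files, seed  # unused in this toy example\n",
"    return tf.data.Dataset.from_tensor_slices(\n",
"        {\"en\": [\"That is good.\"], \"de\": [\"Das ist gut.\"]})\n",
"\n",
"seqio.FunctionDataSource(dataset_fn=my_dataset_fn, splits=[\"train\"])"
]
},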
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Output field Map[str:seqi.Feature]\n",
"# Defines what the Task is expected to produce in its output example\n",
"\n",
"# Contains:\n",
"# - vocabulary (how to tokenize / detokenize)\n",
"# - add_eos\n",
"# - dtype\n",
"output_features={\n",
" 'inputs':\n",
" seqio.Feature(\n",
" seqio.SentencePieceVocabulary('/path/to/inputs/vocab'),\n",
" add_eos=False,\n",
" dtype=tf.int32),\n",
" 'targets':\n",
" seqio.Feature(\n",
" seqio.SentencePieceVocabulary('/path/to/targets/vocab'),\n",
" add_eos=True,\n",
" dtype=tf.int32)\n",
"}\n",
"\n",
"# Decoder only?"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Preprocessors: Transform tf.data.Dataset into a new tf.data.Dataset\n",
"preprocessors=[\n",
" functools.partial(\n",
" translate, source_language='en', target_language='de'),\n",
" seqio.preprocessors.tokenize, seqio.preprocessors.append_eos\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def translate(dataset: tf.data.Dataset,\n",
" source_language: str,\n",
" target_language: str) -> tf.data.Dataset:\n",
" def _translate(ex: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:\n",
" \"\"\"Convert a translation example to a text2text pair.\n",
"\n",
" For example, say the dataset returns examples of this format:\n",
" {'de': 'Das ist gut.', 'en': 'That is good.'}\n",
" If source_language = 'de', target_language = 'en', then the outputs will have\n",
" the format:\n",
" {'inputs': 'translate de to en: Das ist gut.',\n",
" 'targets': 'That is good.'}\n",
"\n",
" Args:\n",
" ex: an example to process.\n",
" source_language: source language code (e.g. 'en') to translate from.\n",
" target_language: target language code (e.g. 'de') to translate to.\n",
"\n",
" Returns:\n",
" A preprocessed example with the format listed above.\n",
" \"\"\"\n",
" src_str = f'translate {source_language}'\n",
" tgt_str = f' to {target_language}: '\n",
" return {\n",
" 'inputs': tf.strings.join([src_str, tgt_str, ex[source_language]]),\n",
" 'targets': ex[target_language],\n",
" }\n",
"\n",
" return dataset.map(_translate,\n",
" num_parallel_calls=tf.data.AUTOTUNE)\n",
"\n",
"# Input\n",
"# {'de': 'Das ist gut.', 'en': 'That is good.'} \n",
"\n",
"# Output\n",
"# { \n",
"# 'input': 'translate de to en: Das ist gut.', \n",
"# 'target': 'That is good.' \n",
"#}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@seqio.map_over_dataset\n",
"def translate(ex: Mapping[str, tf.Tensor],\n",
" source_language: str,\n",
" target_language: str) -> Mapping[str, tf.Tensor]:\n",
" \"\"\"Convert a translation dataset to a text2text pair.\n",
"\n",
" For example, say the dataset returns examples of this format:\n",
" {'de': 'Das ist gut.', 'en': 'That is good.'}\n",
" If source_language = 'de', target_language = 'en', then the outputs will have\n",
" the format:\n",
" {'inputs': 'translate German to English: Das ist gut.',\n",
" 'targets': 'That is good.'}\n",
"\n",
" Args:\n",
" ex: an example to process.\n",
" source_language: source language code (e.g. 'en') to translate from.\n",
" target_language: target language code (e.g. 'de') to translate to.\n",
"\n",
" Returns:\n",
" A preprocessed example with the format listed above.\n",
" \"\"\"\n",
" src_str = f'translate {source_language}'\n",
" tgt_str = f' to {target_language}: '\n",
" return {\n",
" 'inputs': tf.strings.join([src_str, tgt_str, ex[source_language]]),\n",
" 'targets': ex[target_language],\n",
" }"
]
},
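{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# seqio.map_over_dataset lifts the example-level function above into a\n",
"# dataset-level preprocessor, forwarding the extra keyword arguments.\n",
"# A sketch on in-memory toy data:\n",
"raw = tf.data.Dataset.from_tensor_slices(\n",
"    {\"de\": [\"Das ist gut.\"], \"en\": [\"That is good.\"]})\n",
"for ex in translate(raw, source_language=\"de\",\n",
"                    target_language=\"en\").as_numpy_iterator():\n",
"    print(ex)\n",
"# {'inputs': b'translate de to en: Das ist gut.', 'targets': b'That is good.'}"
]
},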
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Post processor and Metrics\n",
"\n",
"Before passing predictions to the evaluation function, that data can be processed with a python function."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Prediction metrics are computed using the postprocessed targets and model outputs (predictions). The args must be named `targets` and `predictions`.\n",
"\n",
"Let's look at the metric function used for \"wmt19_ende\" task. A standard metric for the translation task is BLEU and we use sacrebleu implementation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def tqa_open_preprocessor(\n",
" dataset: tf.data.Dataset,\n",
" prefix: str = \"trivia_qa question: \"\n",
") -> tf.data.Dataset:\n",
"\n",
" @seqio.map_over_dataset\n",
" def tqa_map(ex: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:\n",
" \"\"\"Map TriviaQA example to text-to-text example.\"\"\"\n",
" return {\n",
" \"inputs\": prefix + ex[\"question\"],\n",
" \"targets\": ex[\"answer\"][\"value\"],\n",
" \"answers\": ex[\"answer\"][\"aliases\"],\n",
" }\n",
"\n",
" return tqa_map(dataset)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"postprocess_fn=tqa_open_preprocessor"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Metrics function\n",
"metric_fns=[metrics.bleu]"
]
},
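{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch of a custom metric (hypothetical, not part of seqio): the\n",
"# arguments must be named `targets` and `predictions`, and the function\n",
"# returns a dict mapping metric names to scalar scores.\n",
"def exact_match(targets, predictions):\n",
"    matches = sum(t == p for t, p in zip(targets, predictions))\n",
"    return {\"exact_match\": 100.0 * matches / len(targets)}"
]
},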
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Define Mixtures\n",
"seqio.MixtureRegistry.add(\n",
" \"mix2\",\n",
" [\"task1\", \"task2\"],\n",
" default_rate=seqio.mixing_rate_num_examples\n",
")\n",
"\n",
"# Mixture in your Mixture\n",
"seqio.MixtureRegistry.add(\n",
" \"mix3\",\n",
" [\"mix1\", \"task1\", \"task3\"],\n",
" default_rate=1\n",
")"
]
},
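{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Rates can also be given per task as (name, rate) pairs instead of a\n",
"# default_rate; here \"task1\" is sampled twice as often as \"task2\".\n",
"seqio.MixtureRegistry.add(\n",
"    \"mix4\",\n",
"    [(\"task1\", 2), (\"task2\", 1)]\n",
")"
]
},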
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Getting the dataset\n",
"dataset = seqio.get_mixture_or_task(\"mix1\").get_dataset(\n",
" sequence_length={\"inputs\": 256, \"targets\": 128},\n",
" split=\"train\",\n",
" shuffle=True,\n",
" num_epochs=1,\n",
" shard_info=seqio.ShardInfo(index=0, num_shards=10), # load a deterministic subset of the dataset\n",
" use_cached=False, # load from a pre-cached task or process on-the-fly\n",
" seed=42 # deterministic shuffling and (stateless) stochastic ops\n",
")\n",
"\n",
"# Print the first 5 examples.\n",
"for _, ex in zip(range(5), dataset.as_numpy_iterator()):\n",
" print(ex)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"#### (optional) Offline Caching\n",
"\n",
"For improved performance at load time and avoid redundant computations for commonly used tasks, you can pre-cache your Task with all or part of the preprocessing done in advance of training.\n",
"\n",
"Add:\n",
"\n",
"```python\n",
"seqio.CacheDatasetPlaceholder(required=False)\n",
"```\n",
"\n",
"as one of the steps in your preprocessing pipeline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Run as an Apache Beam job on Cloud DataFlow\n",
"cache_tasks_main"
]
},
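{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Invocation sketch following the SeqIO README (verify flag names against\n",
"# your installed version); the module and paths are placeholders.\n",
"!python -m seqio.scripts.cache_tasks_main \\\n",
"  --module_import=my.tasks \\\n",
"  --tasks=wmt19_ende \\\n",
"  --output_cache_dir=/path/to/cache_dir \\\n",
"  --alsologtostderr"
]
},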
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, you are ready to load the cached version of your Task (or Mixture) containing it. \n",
"You will need to add the path to the directory you passed to `--output_cache_dir` via `seqio.add_global_cache_dirs([\"/my/cache/dir\"])`. \n",
"\n",
"Now when you call `task_or_mixture.get_dataset(..., use_cached=True)`, the data will be loaded from the cache directory instead of the raw data source."
]
},
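{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch of loading from the cache; \"/my/cache/dir\" stands in for the\n",
"# directory passed to --output_cache_dir above.\n",
"seqio.add_global_cache_dirs([\"/my/cache/dir\"])\n",
"\n",
"dataset = seqio.get_mixture_or_task(\"wmt19_ende\").get_dataset(\n",
"    sequence_length={\"inputs\": 256, \"targets\": 128},\n",
"    split=\"train\",\n",
"    use_cached=True)"
]
},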
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## End to end sample"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"DEFAULT_OUTPUT_FEATURES = {\n",
" \"inputs\": seqio.Feature(\n",
" vocabulary=t5.data.get_default_vocabulary(), \n",
" add_eos=True,\n",
" required=False),\n",
" \"targets\": seqio.Feature(\n",
" vocabulary=t5.data.get_default_vocabulary(), \n",
" add_eos=True)\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"@seqio.map_over_dataset\n",
"def summarize(x, article_key, summary_key):\n",
" \"\"\"Convert a summarization dataset to a text2text pair.\n",
"\n",
" For example, say the dataset returns examples of this format:\n",
" {'article': <article>, 'highlights': <summary>}\n",
" If article_key = 'article', summary_key = 'highlights', then the outputs will\n",
" have the format:\n",
" {'inputs': 'summarize': <article>, 'targets': <summary>}\n",
"\n",
" Args:\n",
" x: an example to process.\n",
" article_key: the feature key for the article to summarize.\n",
" summary_key: the feature key for the target summary.\n",
" Returns:\n",
" A preprocessed example with the format listed above.\n",
" \"\"\"\n",
" strs_to_join = ['summarize:', x[article_key]]\n",
" return {\n",
" 'inputs': tf.strings.join(strs_to_join, separator=' '),\n",
" 'targets': x[summary_key],\n",
" }"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"task = seqio.Task(\n",
" name='custom_xsum',\n",
" source=seqio.TFExampleDataSource(\n",
" split_to_filepattern={\n",
" 'validation': 'gs://rl-llm-export-ft/xsum-dataset/validation.tfrecord'\n",
" },\n",
" feature_description={\n",
" 'document': tf.io.FixedLenFeature(shape=(), dtype=tf.string),\n",
" 'summary': tf.io.FixedLenFeature(shape=(), dtype=tf.string)\n",
" }\n",
" ),\n",
" preprocessors=[\n",
" functools.partial(\n",
" summarize,\n",
" article_key=\"document\",\n",
" summary_key=\"summary\"),\n",
" seqio.preprocessors.tokenize,\n",
" seqio.CacheDatasetPlaceholder(),\n",
" seqio.preprocessors.append_eos_after_trim,\n",
" ],\n",
" metric_fns=[metrics.rouge],\n",
" output_features=DEFAULT_OUTPUT_FEATURES\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"task_ds = task.get_dataset(sequence_length={'document':512, 'summary':128}, split='validation').as_numpy_iterator()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'inputs_pretokenized': b'summarize: Each course teaches students, who must be 18 or older when attending, about GCHQ\\'s work to combat cyber-threats and helps them develop their \"cyber-skills\".\\nThe schools are held at four sites in the UK, an increase from two last year.\\nThe expansion was made following high demand for places, GCHQ has said.\\nOne of the schools, Cyber Insiders, will give participants the chance to learn from a range of cybersecurity experts and will be held at Cheltenham from 4 July to 9 September.\\nAnother, named Cyber Exposure, is targeted at students with a \"natural curiosity for technology and problem-solving\".\\nThe Cyber Exposure courses will be held at sites in Scarborough, the Manchester area, and the Thames Valley between 11 July and 19 August.\\n\"It is good that GCHQ is doing this, it increases the number of people that are learning about cybersecurity,\" cryptography expert Prof Mark Ryan, at the University of Birmingham, told the BBC.\\n\"We do have a cybersecurity skills gap where we just aren\\'t training enough people in cybersecurity.\"\\nThe Cheltenham-based school is targeted specifically at first- and second-year students studying computer science, maths, physics or related subjects.\\nThe other courses are open to students of any subject who have five GCSEs, including maths and English, who are also on track to achieve two A-levels at C grade or higher.\\nIn a statement, GCHQ said that work at the summer schools would cover a wide range of technologies.\\n\"Students will learn about GCHQ\\'s role in defending the UK against cyber-threats whilst being paid \\xc2\\xa3250 a week,\" the agency said.\\nApplications are now open at the GCHQ careers website.',\n",
" 'inputs': array([21603, 10, 1698, 503, 3, 11749, 481, 6, 113,\n",
" 398, 36, 507, 42, 2749, 116, 7078, 6, 81,\n",
" 3, 11055, 21447, 31, 7, 161, 12, 4719, 9738,\n",
" 18, 189, 60, 144, 7, 11, 1691, 135, 1344,\n",
" 70, 96, 75, 63, 1152, 18, 7, 10824, 7,\n",
" 1280, 37, 2061, 33, 1213, 44, 662, 1471, 16,\n",
" 8, 1270, 6, 46, 993, 45, 192, 336, 215,\n",
" 5, 37, 5919, 47, 263, 826, 306, 2173, 21,\n",
" 1747, 6, 3, 11055, 21447, 65, 243, 5, 555,\n",
" 13, 8, 2061, 6, 14183, 9014, 52, 7, 6,\n",
" 56, 428, 3008, 8, 1253, 12, 669, 45, 3,\n",
" 9, 620, 13, 28684, 2273, 11, 56, 36, 1213,\n",
" 44, 2556, 40, 324, 1483, 45, 314, 1718, 12,\n",
" 668, 1600, 5, 2351, 6, 2650, 14183, 13471, 4334,\n",
" 6, 19, 7774, 44, 481, 28, 3, 9, 96,\n",
" 14884, 18967, 21, 748, 11, 682, 18, 6065, 53,\n",
" 1280, 37, 14183, 13471, 4334, 2996, 56, 36, 1213,\n",
" 44, 1471, 16, 14586, 12823, 6, 8, 9145, 616,\n",
" 6, 11, 8, 29989, 3460, 344, 850, 1718, 11,\n",
" 957, 1660, 5, 96, 196, 17, 19, 207, 24,\n",
" 3, 11055, 21447, 19, 692, 48, 6, 34, 5386,\n",
" 8, 381, 13, 151, 24, 33, 1036, 81, 28684,\n",
" 976, 17620, 16369, 2205, 7477, 2185, 7826, 6, 44,\n",
" 8, 636, 13, 15922, 6, 1219, 8, 9938, 5,\n",
" 96, 1326, 103, 43, 3, 9, 28684, 1098, 6813,\n",
" 213, 62, 131, 33, 29, 31, 17, 761, 631,\n",
" 151, 16, 28684, 535, 37, 2556, 40, 324, 1483,\n",
" 18, 390, 496, 19, 7774, 3346, 44, 166, 18,\n",
" 11, 511, 18, 1201, 481, 6908, 1218, 2056, 6,\n",
" 7270, 7, 6, 3, 11599, 42, 1341, 7404, 5,\n",
" 37, 119, 2996, 33, 539, 12, 481, 13, 136,\n",
" 1426, 113, 43, 874, 3, 11055, 4132, 7, 6,\n",
" 379, 7270, 7, 11, 1566, 6, 113, 33, 92,\n",
" 30, 1463, 12, 1984, 192, 71, 18, 4563, 7,\n",
" 44, 205, 2769, 42, 1146, 5, 86, 3, 9,\n",
" 2493, 6, 3, 11055, 21447, 243, 24, 161, 44,\n",
" 8, 1248, 2061, 133, 1189, 3, 9, 1148, 620,\n",
" 13, 2896, 5, 96, 13076, 24180, 56, 669, 81,\n",
" 3, 11055, 21447, 31, 7, 1075, 16, 3, 20309,\n",
" 8, 1270, 581, 9738, 18, 189, 60, 144, 7,\n",
" 7096, 7, 17, 271, 1866, 17586, 1752, 3, 9,\n",
" 471, 976, 8, 3193, 243, 5, 15148, 33, 230,\n",
" 539, 44, 8, 3, 11055, 21447, 13325, 475, 5,\n",
" 1], dtype=int32),\n",
" 'targets_pretokenized': b'UK intelligence agency GCHQ has announced it will pay \\xc2\\xa3250 per week to students attending its Cyber Summer Schools this year.',\n",
" 'targets': array([ 1270, 6123, 3193, 3, 11055, 21447, 65, 2162, 34,\n",
" 56, 726, 17586, 1752, 399, 471, 12, 481, 7078,\n",
" 165, 14183, 5550, 13255, 48, 215, 5, 1],\n",
" dtype=int32)}"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"next(iter(task_ds))"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# SeqIO - Feature Converters"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Tasks (task features): Datasets => Provide model-specific features (e.g.: generic inputs and outputs)\n",
"\n",
"Feature Converters (model features): Transform the model-agnostic features => model-specific features"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Raw data\n",
"\"That is good\\tDas ist gut.\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# After preprocessing from SeqIO\n",
"{\"inputs\": \"translate English to German: That is good.\",\n",
" \"targets\": \"Das ist gut.\"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sample output\n",
"dataset = [{\"inputs\": [7, 8, 5], \"targets\": [8, 4, 9, 3, 1]}]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# FeatureConverter: Converts to model-specific features (padding, packing)\n",
"# Transformer: Encoder - Decoder\n",
"converted_dataset = [{\n",
" \"encoder_input_tokens\": [7, 8, 5, 1, 8, 4, 9, 3, 1, 0],\n",
" \"encoder_segment_ids\": [1, 1, 1, 1, 2, 2, 2, 2, 2, 0],\n",
" \"encoder_positions\": [0, 1, 2, 3, 0, 1, 2, 3, 4, 0],\n",
" \"decoder_target_tokens\": [3, 9, 1, 4, 1, 0, 0],\n",
" \"decoder_input_tokens\": [0, 3, 9, 0, 4, 0, 0],\n",
" \"decoder_loss_weights\": [1, 1, 1, 1, 1, 0, 0],\n",
" \"decoder_positions\": [0, 1, 2, 0, 1, 0, 0],\n",
" \"decoder_segment_ids\": [1, 1, 1, 2, 2, 0, 0],\n",
"}]\n",
"\n",
"# encoder-decoder, decoder-only and encoder-only"
]
},
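{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Converters shipped with SeqIO for the architectures above (names as of\n",
"# this writing; check seqio.feature_converters in your version):\n",
"seqio.EncDecFeatureConverter     # encoder-decoder (e.g. T5)\n",
"seqio.LMFeatureConverter         # decoder-only language model\n",
"seqio.PrefixLMFeatureConverter   # decoder-only prefix LM (inputs + targets)"
]
},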
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset: tf.data.Dataset = seqio.get_dataset(\n",
" mixture_or_task_name=\"wmt_t2t_ende_v003\",\n",
" task_feature_lengths={\"inputs\": 32, \"targets\": 32},\n",
" dataset_split=\"train\",\n",
" shuffle=True,\n",
" feature_converter=seqio.EncDecFeatureConverter(pack=True)\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}