@leiterenato
Created November 14, 2022 16:51
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import seqio\n",
"import json"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Create TF Records\n",
"\n",
"I give an example below that writes a single record.\n",
"To write the other samples, iterate through the JSON load; a sketch of that loop follows the single-record example."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# The following functions can be used to convert a value to a type compatible\n",
"# with tf.train.Example.\n",
"\n",
"def _bytes_feature(value):\n",
"    \"\"\"Returns a bytes_list from a string / byte.\"\"\"\n",
"    if isinstance(value, type(tf.constant(0))):\n",
"        value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.\n",
"    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))"
]
},
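{
"cell_type": "markdown",
"metadata": {},
"source": [
"Only the bytes helper is needed in this notebook (every feature is a string), but if numeric features come up later, the matching helpers follow the same pattern. A sketch, not used below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: companion helpers for numeric features. Not used in this notebook,\n",
"# which only stores string features.\n",
"def _float_feature(value):\n",
"    \"\"\"Returns a float_list from a float / double.\"\"\"\n",
"    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))\n",
"\n",
"def _int64_feature(value):\n",
"    \"\"\"Returns an int64_list from a bool / enum / int / uint.\"\"\"\n",
"    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))"
]
},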
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Create a dictionary with features that may be relevant.\n",
"def sample_example(one_sample: dict):\n",
"    feature = {\n",
"        'id': _bytes_feature(one_sample['id'].encode('UTF-8')),\n",
"        'context': _bytes_feature(one_sample['context'].encode('UTF-8')),\n",
"        'question': _bytes_feature(one_sample['question'].encode('UTF-8')),\n",
"        'answers': _bytes_feature(one_sample['answers']['text'][0].encode('UTF-8'))\n",
"    }\n",
"\n",
"    return tf.train.Example(features=tf.train.Features(feature=feature))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Example: write a single record to a TFRecord file.\n",
"with open('squad1.json', 'r') as fp:\n",
"    sample = json.load(fp)\n",
"\n",
"record_file = 'samples.tfrecords'\n",
"\n",
"with tf.io.TFRecordWriter(record_file) as writer:\n",
"    tf_example = sample_example(sample[0])\n",
"    writer.write(tf_example.SerializeToString())"
]
},
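{
"cell_type": "markdown",
"metadata": {},
"source": [
"A sketch of the full loop mentioned above, assuming `sample` is a list of SQuAD-style dicts with the keys used by `sample_example`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: write every sample in the JSON load, not just the first one.\n",
"# Assumes `sample` is a list of dicts with the keys used by sample_example().\n",
"with tf.io.TFRecordWriter(record_file) as writer:\n",
"    for one_sample in sample:\n",
"        tf_example = sample_example(one_sample)\n",
"        writer.write(tf_example.SerializeToString())"
]
},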
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Define data preprocessing and test the dataset\n",
"\n",
"I copied and pasted the SQuAD preprocessing functions here."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Create a description of the features.\n",
"feature_description = {\n",
"    'id': tf.io.FixedLenFeature([], tf.string),\n",
"    'context': tf.io.FixedLenFeature([], tf.string),\n",
"    'question': tf.io.FixedLenFeature([], tf.string),\n",
"    'answers': tf.io.FixedLenFeature([], tf.string),\n",
"}"
]
},
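{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional sanity check: a sketch using plain `tf.data` (independent of seqio) that parses the record written above back with this description."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: read the TFRecord back with plain tf.data to check the description.\n",
"raw_ds = tf.data.TFRecordDataset('samples.tfrecords')\n",
"parsed_ds = raw_ds.map(\n",
"    lambda record: tf.io.parse_single_example(record, feature_description))\n",
"for parsed in parsed_ds.as_numpy_iterator():\n",
"    print(parsed)"
]
},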
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"source = seqio.TFExampleDataSource(\n",
"    split_to_filepattern={'train': 'samples.tfrecords'},\n",
"    feature_description=feature_description\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ds = source.get_dataset(split='train')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for i in ds.as_numpy_iterator():\n",
"    print(i)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Create a seqio task"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"import functools\n",
"import seqio\n",
"import t5.data\n",
"from t5.data import postprocessors\n",
"from t5.data import preprocessors\n",
"from t5.evaluation import metrics"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"def _string_join(lst):\n",
"    # Join on space, but collapse consecutive spaces.\n",
"    out = tf.strings.join(lst, separator=' ')\n",
"    return tf.strings.regex_replace(out, r'\\s+', ' ')"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"def _pad_punctuation(text):\n",
"    \"\"\"Adds spaces around punctuation.\"\"\"\n",
"    # Add space around punctuation.\n",
"    text = tf.strings.regex_replace(text, r'([[:punct:]])', r' \\1 ')\n",
"    # Collapse consecutive whitespace into one space.\n",
"    text = tf.strings.regex_replace(text, r'\\s+', ' ')\n",
"    return text"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"@seqio.map_over_dataset\n",
"def squad(x, include_context=True):\n",
"    \"\"\"Convert SQuAD examples to a text2text pair.\n",
"\n",
"    SQuAD produces examples with this form:\n",
"      {'id': <id>, 'context': <article>, 'question': <question>,\n",
"       'answers': {'text': [<n answers>]}}\n",
"    In this notebook 'answers' was already flattened to a single string when\n",
"    the TFRecord was written, so it is used directly.\n",
"\n",
"    This function will return examples of the format:\n",
"      {'inputs': 'question: <question> context: <article>',\n",
"       'targets': '<answer_0>',\n",
"       'id': <id>, 'question': <question>, 'context': <context>,\n",
"       'answers': <answer_0>}\n",
"\n",
"    Args:\n",
"      x: an example to process.\n",
"      include_context: a boolean, whether to include the context in the inputs.\n",
"\n",
"    Returns:\n",
"      A preprocessed example with the format listed above.\n",
"    \"\"\"\n",
"    a = _pad_punctuation(x['answers'])\n",
"    q = _pad_punctuation(x['question'])\n",
"    c = _pad_punctuation(x['context'])\n",
"    if include_context:\n",
"        inputs = _string_join(['question:', q, 'context:', c])\n",
"    else:\n",
"        inputs = _string_join(['squad trivia question:', q])\n",
"    return {\n",
"        'inputs': inputs,\n",
"        'targets': a,\n",
"        'id': x['id'],\n",
"        'context': c,\n",
"        'question': q,\n",
"        'answers': a\n",
"    }"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"DEFAULT_OUTPUT_FEATURES = {\n",
"    \"inputs\": seqio.Feature(\n",
"        vocabulary=t5.data.get_default_vocabulary(), add_eos=True,\n",
"        required=False),\n",
"    \"targets\": seqio.Feature(\n",
"        vocabulary=t5.data.get_default_vocabulary(), add_eos=True)\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"task = seqio.Task(\n",
"    \"squad_v010_allanswers\",\n",
"    source=seqio.TFExampleDataSource(\n",
"        split_to_filepattern={'train': 'samples.tfrecords'},\n",
"        feature_description=feature_description\n",
"    ),\n",
"    preprocessors=[\n",
"        functools.partial(squad),\n",
"        seqio.preprocessors.tokenize,\n",
"        seqio.CacheDatasetPlaceholder(),\n",
"        seqio.preprocessors.append_eos_after_trim,\n",
"    ],\n",
"    postprocess_fn=postprocessors.qa,\n",
"    metric_fns=[metrics.squad],\n",
"    output_features=DEFAULT_OUTPUT_FEATURES)"
]
},
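{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the task needs to be looked up by name elsewhere (e.g. from a training config), it can be registered instead of instantiated directly. A sketch with the same arguments, passed to `seqio.TaskRegistry.add` in place of the `seqio.Task` constructor:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: register the same task so it can be retrieved with\n",
"# seqio.get_mixture_or_task('squad_v010_allanswers').\n",
"seqio.TaskRegistry.add(\n",
"    'squad_v010_allanswers',\n",
"    source=seqio.TFExampleDataSource(\n",
"        split_to_filepattern={'train': 'samples.tfrecords'},\n",
"        feature_description=feature_description\n",
"    ),\n",
"    preprocessors=[\n",
"        squad,\n",
"        seqio.preprocessors.tokenize,\n",
"        seqio.CacheDatasetPlaceholder(),\n",
"        seqio.preprocessors.append_eos_after_trim,\n",
"    ],\n",
"    postprocess_fn=postprocessors.qa,\n",
"    metric_fns=[metrics.squad],\n",
"    output_features=DEFAULT_OUTPUT_FEATURES)"
]
},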
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"ds = task.get_dataset(split='train', sequence_length={'inputs': 128, 'targets': 128})"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'inputs_pretokenized': b'question: Which NFL team represented the NFC at Super Bowl 50 ? context: Super Bowl 50 was an American football game to determine the champion of the National Football League ( NFL ) for the 2015 season . The American Football Conference ( AFC ) champion Denver Broncos defeated the National Football Conference ( NFC ) champion Carolina Panthers 24\\xe2\\x80\\x9310 to earn their third Super Bowl title . The game was played on February 7 , 2016 , at Levi \\' s Stadium in the San Francisco Bay Area at Santa Clara , California . As this was the 50th Super Bowl , the league emphasized the \" golden anniversary \" with various gold - themed initiatives , as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals ( under which the game would have been known as \" Super Bowl L \" ) , so that the logo could prominently feature the Arabic numerals 50 . ', 'inputs': array([ 822, 10, 4073, 10439, 372, 7283, 8, 445, 5390,\n",
" 44, 2011, 9713, 943, 3, 58, 2625, 10, 2011,\n",
" 9713, 943, 47, 46, 797, 3370, 467, 12, 2082,\n",
" 8, 6336, 13, 8, 868, 10929, 3815, 41, 10439,\n",
" 3, 61, 21, 8, 1230, 774, 3, 5, 37,\n",
" 797, 10929, 4379, 41, 71, 5390, 3, 61, 6336,\n",
" 12154, 4027, 29, 509, 7, 17025, 8, 868, 10929,\n",
" 4379, 41, 445, 5390, 3, 61, 6336, 5089, 21149,\n",
" 7, 997, 104, 1714, 12, 3807, 70, 1025, 2011,\n",
" 9713, 2233, 3, 5, 37, 467, 47, 1944, 30,\n",
" 2083, 489, 3, 6, 1421, 3, 6, 44, 16755,\n",
" 3, 31, 3, 7, 12750, 16, 8, 1051, 5901,\n",
" 2474, 5690, 44, 4625, 9908, 9, 3, 6, 1826,\n",
" 3, 5, 282, 48, 47, 8, 943, 189, 2011,\n",
" 9713, 1], dtype=int32), 'targets_pretokenized': b'Carolina Panthers', 'targets': array([ 5089, 21149, 7, 1], dtype=int32), 'id': b'56be4db0acb8001400a502ed', 'context': b'Super Bowl 50 was an American football game to determine the champion of the National Football League ( NFL ) for the 2015 season . The American Football Conference ( AFC ) champion Denver Broncos defeated the National Football Conference ( NFC ) champion Carolina Panthers 24\\xe2\\x80\\x9310 to earn their third Super Bowl title . The game was played on February 7 , 2016 , at Levi \\' s Stadium in the San Francisco Bay Area at Santa Clara , California . As this was the 50th Super Bowl , the league emphasized the \" golden anniversary \" with various gold - themed initiatives , as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals ( under which the game would have been known as \" Super Bowl L \" ) , so that the logo could prominently feature the Arabic numerals 50 . ', 'question': b'Which NFL team represented the NFC at Super Bowl 50 ? ', 'answers': b'Carolina Panthers'}\n"
]
}
],
"source": [
"for i in ds.as_numpy_iterator():\n",
"    print(i)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.7.12 ('base')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}