Created
November 14, 2022 16:51
-
-
Save leiterenato/ffa4ff7efdcf0ce238ba780757688f9a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import tensorflow as tf\n", | |
"import seqio\n", | |
"import json" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Create TF Records\n", | |
"\n", | |
"The example below writes a single record to show the format.  \n", | |
"To write the remaining samples, iterate over the list loaded from the JSON file and write each one the same way." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# The following function converts a value to a type compatible\n", | |
"# with tf.train.Example.\n", | |
"\n", | |
"def _bytes_feature(value):\n", | |
"    \"\"\"Returns a bytes_list Feature from a str, bytes, or eager tensor value.\n", | |
"\n", | |
"    Args:\n", | |
"        value: `bytes`, `str` (encoded as UTF-8 here), or an eager tensor,\n", | |
"            which is unwrapped with `.numpy()` first.\n", | |
"\n", | |
"    Returns:\n", | |
"        A `tf.train.Feature` whose `bytes_list` holds `value` as its only element.\n", | |
"    \"\"\"\n", | |
"    if isinstance(value, type(tf.constant(0))):\n", | |
"        value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.\n", | |
"    if isinstance(value, str):\n", | |
"        value = value.encode('UTF-8')  # BytesList rejects str; it needs bytes.\n", | |
"    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Create a dictionary with features that may be relevant.\n", | |
"def sample_example(one_sample: dict):\n", | |
"    \"\"\"Builds a `tf.train.Example` from one SQuAD-style sample.\n", | |
"\n", | |
"    Args:\n", | |
"        one_sample: dict with string values under 'id', 'context' and\n", | |
"            'question', and a list of answer strings under\n", | |
"            one_sample['answers']['text'].\n", | |
"\n", | |
"    Returns:\n", | |
"        A `tf.train.Example` with UTF-8 encoded bytes features 'id',\n", | |
"        'context', 'question' and 'answers'. Only the first answer is kept;\n", | |
"        a sample whose answer list is empty gets an empty string instead of\n", | |
"        raising IndexError.\n", | |
"    \"\"\"\n", | |
"    answer_texts = one_sample['answers']['text']\n", | |
"    first_answer = answer_texts[0] if answer_texts else ''\n", | |
"\n", | |
"    feature = {\n", | |
"        'id': _bytes_feature(one_sample['id'].encode('UTF-8')),\n", | |
"        'context': _bytes_feature(one_sample['context'].encode('UTF-8')),\n", | |
"        'question': _bytes_feature(one_sample['question'].encode('UTF-8')),\n", | |
"        'answers': _bytes_feature(first_answer.encode('UTF-8'))\n", | |
"    }\n", | |
"\n", | |
"    return tf.train.Example(features=tf.train.Features(feature=feature))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# This is an example: only the first sample is written to the record file.\n", | |
"# The JSON file is read as UTF-8 explicitly so that non-ASCII characters in\n", | |
"# the text do not depend on the platform's default encoding.\n", | |
"with open('squad1.json', 'r', encoding='utf-8') as fp:\n", | |
"    sample = json.load(fp)\n", | |
"\n", | |
"record_file = 'samples.tfrecords'\n", | |
"\n", | |
"with tf.io.TFRecordWriter(record_file) as writer:\n", | |
"    tf_example = sample_example(sample[0])\n", | |
"    writer.write(tf_example.SerializeToString())" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Define data preprocessing and test the dataset\n", | |
"\n", | |
"The SQuAD preprocessing functions are copied and pasted here." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Parsing spec for the serialized examples: each field is a scalar string.\n", | |
"feature_description = {\n", | |
"    name: tf.io.FixedLenFeature([], tf.string)\n", | |
"    for name in ('id', 'context', 'question', 'answers')\n", | |
"}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"source = seqio.TFExampleDataSource(\n", | |
" split_to_filepattern={'train': 'samples.tfrecords'},\n", | |
" feature_description=feature_description\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"ds = source.get_dataset(split='train')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"for i in ds.as_numpy_iterator():\n", | |
" print(i)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Create a seqio task" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import functools\n", | |
"import seqio\n", | |
"import t5.data\n", | |
"from t5.data import postprocessors\n", | |
"from t5.data import preprocessors\n", | |
"from t5.evaluation import metrics" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def _string_join(lst):\n", | |
"    \"\"\"Joins the pieces in `lst` with single spaces and squeezes whitespace runs.\"\"\"\n", | |
"    joined = tf.strings.join(lst, separator=' ')\n", | |
"    # Pieces may carry their own leading/trailing whitespace, so reduce every\n", | |
"    # whitespace run produced by the join to one space.\n", | |
"    squeezed = tf.strings.regex_replace(joined, r'\\s+', ' ')\n", | |
"    return squeezed" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def _pad_punctuation(text):\n", | |
"    \"\"\"Surrounds every punctuation character with spaces, then squeezes whitespace.\"\"\"\n", | |
"    spaced = tf.strings.regex_replace(text, r'([[:punct:]])', r' \\1 ')\n", | |
"    # The padding can leave runs of whitespace; reduce each run to one space.\n", | |
"    return tf.strings.regex_replace(spaced, r'\\s+', ' ')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"@seqio.map_over_dataset\n", | |
"def squad(x, include_context=True):\n", | |
"    \"\"\"Convert a single-answer SQuAD example to a text2text pair.\n", | |
"\n", | |
"    Expects examples of this form, where 'answers' is one string rather than\n", | |
"    a list of answers:\n", | |
"      {'id': <id>, 'context': <article>, 'question': <question>,\n", | |
"       'answers': <answer>}\n", | |
"    This function will return examples of the format:\n", | |
"      {'inputs': 'question: <question> context: <article>',\n", | |
"       'targets': '<answer>',\n", | |
"       'id': <id>, 'question': <question>, 'context': <context>,\n", | |
"       'answers': [<answer>]},\n", | |
"    where question, context and answer have spaces added around punctuation.\n", | |
"\n", | |
"    Args:\n", | |
"      x: an example to process.\n", | |
"      include_context: a boolean; when False the inputs become\n", | |
"        'squad trivia question: <question>' and omit the context.\n", | |
"    Returns:\n", | |
"      A preprocessed example with the format listed above.\n", | |
"    \"\"\"\n", | |
"    a = _pad_punctuation(x['answers'])\n", | |
"    q = _pad_punctuation(x['question'])\n", | |
"    c = _pad_punctuation(x['context'])\n", | |
"    if include_context:\n", | |
"        inputs = _string_join(['question:', q, 'context:', c])\n", | |
"    else:\n", | |
"        inputs = _string_join(['squad trivia question:', q])\n", | |
"    return {\n", | |
"        'inputs': inputs,\n", | |
"        'targets': a,\n", | |
"        'id': x['id'],\n", | |
"        'context': c,\n", | |
"        'question': q,\n", | |
"        # Emit a one-element vector: consumers that iterate over the answers\n", | |
"        # (such as t5's postprocessors.qa) must see whole answer strings, not\n", | |
"        # the individual bytes of a scalar string.\n", | |
"        'answers': tf.expand_dims(a, axis=0)\n", | |
"    }" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Both text fields use the default T5 vocabulary and get an EOS token;\n", | |
"# \"inputs\" is explicitly marked as not required.\n", | |
"DEFAULT_OUTPUT_FEATURES = dict(\n", | |
"    inputs=seqio.Feature(\n", | |
"        vocabulary=t5.data.get_default_vocabulary(),\n", | |
"        add_eos=True,\n", | |
"        required=False,\n", | |
"    ),\n", | |
"    targets=seqio.Feature(\n", | |
"        vocabulary=t5.data.get_default_vocabulary(),\n", | |
"        add_eos=True,\n", | |
"    ),\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Task reading the TFRecord written above, applying the SQuAD text2text\n", | |
"# preprocessing, tokenizing, and scoring with the SQuAD metric.\n", | |
"task = seqio.Task(\n", | |
"    \"squad_v010_allanswers\",\n", | |
"    source=seqio.TFExampleDataSource(\n", | |
"        split_to_filepattern={'train': 'samples.tfrecords'},\n", | |
"        feature_description=feature_description\n", | |
"    ),\n", | |
"    preprocessors=[\n", | |
"        # Passed directly: wrapping it in functools.partial with no bound\n", | |
"        # arguments added nothing.\n", | |
"        squad,\n", | |
"        seqio.preprocessors.tokenize,\n", | |
"        seqio.CacheDatasetPlaceholder(),\n", | |
"        seqio.preprocessors.append_eos_after_trim,\n", | |
"    ],\n", | |
"    postprocess_fn=postprocessors.qa,\n", | |
"    metric_fns=[metrics.squad],\n", | |
"    output_features=DEFAULT_OUTPUT_FEATURES)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"ds = task.get_dataset(split='train', sequence_length={'inputs': 128, 'targets': 128})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"{'inputs_pretokenized': b'question: Which NFL team represented the NFC at Super Bowl 50 ? context: Super Bowl 50 was an American football game to determine the champion of the National Football League ( NFL ) for the 2015 season . The American Football Conference ( AFC ) champion Denver Broncos defeated the National Football Conference ( NFC ) champion Carolina Panthers 24\\xe2\\x80\\x9310 to earn their third Super Bowl title . The game was played on February 7 , 2016 , at Levi \\' s Stadium in the San Francisco Bay Area at Santa Clara , California . As this was the 50th Super Bowl , the league emphasized the \" golden anniversary \" with various gold - themed initiatives , as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals ( under which the game would have been known as \" Super Bowl L \" ) , so that the logo could prominently feature the Arabic numerals 50 . ', 'inputs': array([ 822, 10, 4073, 10439, 372, 7283, 8, 445, 5390,\n", | |
" 44, 2011, 9713, 943, 3, 58, 2625, 10, 2011,\n", | |
" 9713, 943, 47, 46, 797, 3370, 467, 12, 2082,\n", | |
" 8, 6336, 13, 8, 868, 10929, 3815, 41, 10439,\n", | |
" 3, 61, 21, 8, 1230, 774, 3, 5, 37,\n", | |
" 797, 10929, 4379, 41, 71, 5390, 3, 61, 6336,\n", | |
" 12154, 4027, 29, 509, 7, 17025, 8, 868, 10929,\n", | |
" 4379, 41, 445, 5390, 3, 61, 6336, 5089, 21149,\n", | |
" 7, 997, 104, 1714, 12, 3807, 70, 1025, 2011,\n", | |
" 9713, 2233, 3, 5, 37, 467, 47, 1944, 30,\n", | |
" 2083, 489, 3, 6, 1421, 3, 6, 44, 16755,\n", | |
" 3, 31, 3, 7, 12750, 16, 8, 1051, 5901,\n", | |
" 2474, 5690, 44, 4625, 9908, 9, 3, 6, 1826,\n", | |
" 3, 5, 282, 48, 47, 8, 943, 189, 2011,\n", | |
" 9713, 1], dtype=int32), 'targets_pretokenized': b'Carolina Panthers', 'targets': array([ 5089, 21149, 7, 1], dtype=int32), 'id': b'56be4db0acb8001400a502ed', 'context': b'Super Bowl 50 was an American football game to determine the champion of the National Football League ( NFL ) for the 2015 season . The American Football Conference ( AFC ) champion Denver Broncos defeated the National Football Conference ( NFC ) champion Carolina Panthers 24\\xe2\\x80\\x9310 to earn their third Super Bowl title . The game was played on February 7 , 2016 , at Levi \\' s Stadium in the San Francisco Bay Area at Santa Clara , California . As this was the 50th Super Bowl , the league emphasized the \" golden anniversary \" with various gold - themed initiatives , as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals ( under which the game would have been known as \" Super Bowl L \" ) , so that the logo could prominently feature the Arabic numerals 50 . ', 'question': b'Which NFL team represented the NFC at Super Bowl 50 ? ', 'answers': b'Carolina Panthers'}\n" | |
] | |
} | |
], | |
"source": [ | |
"for i in ds.as_numpy_iterator():\n", | |
" print(i)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3.7.12 ('base')", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.12" | |
}, | |
"orig_nbformat": 4, | |
"vscode": { | |
"interpreter": { | |
"hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe" | |
} | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment