Last active: May 4, 2021 07:48
-
-
Save TzuChiehHung/5e5ce60a914eb8db9bef549a2e5819c5 to your computer and use it in GitHub Desktop.
[TFRecord dataset shuffle batch example] #tensorflow #python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import tensorflow as tf\n", | |
"\n", | |
"def _float_feature(value):\n", | |
" \"\"\"Wrapper for inserting float features into Example proto.\"\"\"\n", | |
" if not isinstance(value, list):\n", | |
" value = [value]\n", | |
" return tf.train.Feature(float_list=tf.train.FloatList(value=value))\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# create tfrecords\n", | |
"for j in range(2):\n", | |
" writer = tf.python_io.TFRecordWriter('test%d.tfrecords' %j)\n", | |
"\n", | |
" for i in range(9):\n", | |
" features = {\n", | |
" 'input': _float_feature(10*j+i),\n", | |
" 'output': _float_feature(10*j+0.1*i),\n", | |
" }\n", | |
" example = tf.train.Example(features=tf.train.Features(feature=features))\n", | |
" writer.write(example.SerializeToString())\n", | |
" writer.close()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch #1\n", | |
"(array([ 3., 11., 17.], dtype=float32), array([ 0.30000001, 10.10000038, 10.69999981], dtype=float32))\n", | |
"(array([ 2., 5., 12.], dtype=float32), array([ 0.2 , 0.5 , 10.19999981], dtype=float32))\n", | |
"(array([ 1., 0., 4.], dtype=float32), array([ 0.1 , 0. , 0.40000001], dtype=float32))\n", | |
"(array([ 16., 8., 14.], dtype=float32), array([ 10.60000038, 0.80000001, 10.39999962], dtype=float32))\n", | |
"(array([ 7., 13., 15.], dtype=float32), array([ 0.69999999, 10.30000019, 10.5 ], dtype=float32))\n", | |
"(array([ 10., 18., 6.], dtype=float32), array([ 10. , 10.80000019, 0.60000002], dtype=float32))\n", | |
"End of epoch\n", | |
"\n", | |
"Epoch #2\n", | |
"(array([ 13., 0., 1.], dtype=float32), array([ 10.30000019, 0. , 0.1 ], dtype=float32))\n", | |
"(array([ 18., 16., 12.], dtype=float32), array([ 10.80000019, 10.60000038, 10.19999981], dtype=float32))\n", | |
"(array([ 15., 4., 10.], dtype=float32), array([ 10.5 , 0.40000001, 10. ], dtype=float32))\n", | |
"(array([ 7., 5., 6.], dtype=float32), array([ 0.69999999, 0.5 , 0.60000002], dtype=float32))\n", | |
"(array([ 3., 17., 2.], dtype=float32), array([ 0.30000001, 10.69999981, 0.2 ], dtype=float32))\n", | |
"(array([ 11., 8., 14.], dtype=float32), array([ 10.10000038, 0.80000001, 10.39999962], dtype=float32))\n", | |
"End of epoch\n", | |
"\n", | |
"End of training\n" | |
] | |
} | |
], | |
"source": [ | |
"def parse_function(example):\n", | |
" tmp = tf.parse_single_example(example, features={\n", | |
" 'input': tf.FixedLenFeature([], tf.float32),\n", | |
" 'output': tf.FixedLenFeature([], tf.float32)})\n", | |
" return tmp['input'], tmp['output']\n", | |
"\n", | |
"\n", | |
"files = tf.data.Dataset.list_files('test*.tfrecords')\n", | |
"dataset = files.interleave(tf.data.TFRecordDataset, 1)\n", | |
"dataset = dataset.shuffle(buffer_size=18)\n", | |
"dataset = dataset.map(parse_function, num_parallel_calls=32) # Parse the record into tensors.\n", | |
"dataset = dataset.batch(3)\n", | |
"dataset = dataset.prefetch(buffer_size= 3)\n", | |
"dataset = dataset.repeat(1) # Repeat the input indefinitely.\n", | |
"iterator = dataset.make_initializable_iterator()\n", | |
"\n", | |
"next_batch = iterator.get_next()\n", | |
"\n", | |
"num_epochs = 2\n", | |
"\n", | |
"with tf.Session() as sess:\n", | |
" for i in range(2):\n", | |
" # Resets the iterator at the beginning of an epoch.\n", | |
" print('Epoch #{:d}'.format(i+1))\n", | |
" sess.run(iterator.initializer)\n", | |
"\n", | |
" try:\n", | |
" while True:\n", | |
" print(sess.run(next_batch))\n", | |
" except tf.errors.OutOfRangeError:\n", | |
" # This will be raised when you reach the end of an epoch (i.e. the\n", | |
" # iterator has no more elements).\n", | |
" pass \n", | |
"\n", | |
" # Perform any end-of-epoch computation here.\n", | |
" print('End of epoch\\n')\n", | |
"\n", | |
" print('End of training')\n" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.12" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.