Skip to content

Instantly share code, notes, and snippets.

@TzuChiehHung
Last active May 4, 2021 07:48
Show Gist options
  • Save TzuChiehHung/5e5ce60a914eb8db9bef549a2e5819c5 to your computer and use it in GitHub Desktop.
Save TzuChiehHung/5e5ce60a914eb8db9bef549a2e5819c5 to your computer and use it in GitHub Desktop.
[TFRecord dataset shuffle batch example] #tensorflow #python
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"def _float_feature(value):\n",
"    \"\"\"Wrap a float, or a list/tuple of floats, in a tf.train.Feature.\n",
"\n",
"    A bare scalar is promoted to a one-element list because the `value`\n",
"    field of tf.train.FloatList is a repeated field.\n",
"    \"\"\"\n",
"    if not isinstance(value, (list, tuple)):\n",
"        value = [value]\n",
"    return tf.train.Feature(float_list=tf.train.FloatList(value=value))\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Create test0.tfrecords and test1.tfrecords with 9 Examples each. File j\n",
"# stores input = 10*j + i and output = 10*j + 0.1*i for i = 0..8, so every\n",
"# record stays identifiable after the shuffle in the next cell.\n",
"num_files = 2\n",
"records_per_file = 9\n",
"\n",
"for j in range(num_files):\n",
"    # The with-block closes the writer even if writing raises part-way.\n",
"    with tf.python_io.TFRecordWriter('test{:d}.tfrecords'.format(j)) as writer:\n",
"        for i in range(records_per_file):\n",
"            features = {\n",
"                'input': _float_feature(10*j+i),\n",
"                'output': _float_feature(10*j+0.1*i),\n",
"            }\n",
"            example = tf.train.Example(features=tf.train.Features(feature=features))\n",
"            writer.write(example.SerializeToString())"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch #1\n",
"(array([ 3., 11., 17.], dtype=float32), array([ 0.30000001, 10.10000038, 10.69999981], dtype=float32))\n",
"(array([ 2., 5., 12.], dtype=float32), array([ 0.2 , 0.5 , 10.19999981], dtype=float32))\n",
"(array([ 1., 0., 4.], dtype=float32), array([ 0.1 , 0. , 0.40000001], dtype=float32))\n",
"(array([ 16., 8., 14.], dtype=float32), array([ 10.60000038, 0.80000001, 10.39999962], dtype=float32))\n",
"(array([ 7., 13., 15.], dtype=float32), array([ 0.69999999, 10.30000019, 10.5 ], dtype=float32))\n",
"(array([ 10., 18., 6.], dtype=float32), array([ 10. , 10.80000019, 0.60000002], dtype=float32))\n",
"End of epoch\n",
"\n",
"Epoch #2\n",
"(array([ 13., 0., 1.], dtype=float32), array([ 10.30000019, 0. , 0.1 ], dtype=float32))\n",
"(array([ 18., 16., 12.], dtype=float32), array([ 10.80000019, 10.60000038, 10.19999981], dtype=float32))\n",
"(array([ 15., 4., 10.], dtype=float32), array([ 10.5 , 0.40000001, 10. ], dtype=float32))\n",
"(array([ 7., 5., 6.], dtype=float32), array([ 0.69999999, 0.5 , 0.60000002], dtype=float32))\n",
"(array([ 3., 17., 2.], dtype=float32), array([ 0.30000001, 10.69999981, 0.2 ], dtype=float32))\n",
"(array([ 11., 8., 14.], dtype=float32), array([ 10.10000038, 0.80000001, 10.39999962], dtype=float32))\n",
"End of epoch\n",
"\n",
"End of training\n"
]
}
],
"source": [
"def parse_function(example):\n",
"    \"\"\"Decode one serialized tf.train.Example into an (input, output) pair.\n",
"\n",
"    Both features are parsed as scalar float32 values, matching the records\n",
"    written with _float_feature in the record-creation cell.\n",
"    \"\"\"\n",
"    tmp = tf.parse_single_example(example, features={\n",
"        'input': tf.FixedLenFeature([], tf.float32),\n",
"        'output': tf.FixedLenFeature([], tf.float32)})\n",
"    return tmp['input'], tmp['output']\n",
"\n",
"\n",
"files = tf.data.Dataset.list_files('test*.tfrecords')\n",
"# cycle_length=1: read the files one after another.\n",
"dataset = files.interleave(tf.data.TFRecordDataset, 1)\n",
"# 18 = 2 files x 9 records, so the buffer holds the whole dataset and the\n",
"# shuffle covers every record; the order is re-drawn each time the iterator\n",
"# is initialized.\n",
"dataset = dataset.shuffle(buffer_size=18)\n",
"dataset = dataset.map(parse_function, num_parallel_calls=32)  # Parse the record into tensors.\n",
"dataset = dataset.batch(3)\n",
"dataset = dataset.prefetch(buffer_size=3)  # Applied after batch(), so this counts batches.\n",
"# No .repeat(): one pass over `dataset` is exactly one epoch, and epochs are\n",
"# driven by re-initializing the iterator in the loop below.\n",
"iterator = dataset.make_initializable_iterator()\n",
"\n",
"next_batch = iterator.get_next()\n",
"\n",
"num_epochs = 2\n",
"\n",
"with tf.Session() as sess:\n",
"    for i in range(num_epochs):\n",
"        # Reset the iterator at the beginning of each epoch.\n",
"        print('Epoch #{:d}'.format(i+1))\n",
"        sess.run(iterator.initializer)\n",
"\n",
"        try:\n",
"            while True:\n",
"                print(sess.run(next_batch))\n",
"        except tf.errors.OutOfRangeError:\n",
"            # Raised when the iterator has no more elements, i.e. the epoch\n",
"            # is over; this is the expected loop exit, not an error.\n",
"            pass\n",
"\n",
"        # Perform any end-of-epoch computation here.\n",
"        print('End of epoch\\n')\n",
"\n",
"    print('End of training')\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment