-
-
Save shug3502/36e33392559c0a3b54534a898f5177e4 to your computer and use it in GitHub Desktop.
TFRecords Example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import tensorflow as tf" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def tf_bytes_feature(x):\n", | |
"    # Wrap a raw bytes value in a tf.train.Feature (BytesList of length 1).\n", | |
"    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[x]))\n", | |
"\n", | |
"def tf_int_feature(x):\n", | |
"    # Wrap a scalar int in a tf.train.Feature (Int64List of length 1).\n", | |
"    return tf.train.Feature(int64_list=tf.train.Int64List(value=[x]))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### 1. Create some dummy data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Random uint8 'image' (100x100x3) and float32 'boxes' (20x4) stand in for real data.\n", | |
"dummy_image = np.random.rand(100, 100, 3) * 255\n", | |
"dummy_image = dummy_image.astype(np.uint8)\n", | |
"\n", | |
"dummy_annotations = np.random.rand(20, 4) * 255\n", | |
"dummy_annotations = dummy_annotations.astype(np.float32)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### 2. Serialize into strings" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# ndarray.tostring() is deprecated (and removed in NumPy 2.0);\n", | |
"# tobytes() produces the identical byte string.\n", | |
"raw_image = dummy_image.tobytes()\n", | |
"raw_data = dummy_annotations.tobytes()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### 3. Stuff the data into an example protobuffer" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Pack the raw bytes plus the shape metadata needed to decode them back later.\n", | |
"example = tf.train.Example(features=tf.train.Features(\n", | |
"    feature={\n", | |
"        'image/height': tf_int_feature(dummy_image.shape[0]),\n", | |
"        'image/width': tf_int_feature(dummy_image.shape[1]),\n", | |
"        'image/channels': tf_int_feature(dummy_image.shape[2]),\n", | |
"        'image/encoded': tf_bytes_feature(raw_image),\n", | |
"        'boxes/n_boxes': tf_int_feature(dummy_annotations.shape[0]),\n", | |
"        'boxes/encoded': tf_bytes_feature(raw_data)\n", | |
"    }\n", | |
"))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### 4. Write it into a TFRecords file" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Use the writer as a context manager so the file is flushed and closed\n", | |
"# even if write() raises (the manual writer/close pair leaked on error).\n", | |
"with tf.python_io.TFRecordWriter('dummy_example.tfrecords') as writer:\n", | |
"    writer.write(example.SerializeToString())" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### 5. Read it back" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def parse_serialized(serialized_example):\n", | |
"    # Parse one serialized tf.train.Example back into (image, boxes) tensors,\n", | |
"    # inverting the encoding performed above.\n", | |
"\n", | |
"    features = tf.parse_single_example(\n", | |
"        serialized_example,\n", | |
"        features={\n", | |
"            'image/height': tf.FixedLenFeature([], tf.int64),\n", | |
"            'image/width': tf.FixedLenFeature([], tf.int64),\n", | |
"            'image/channels': tf.FixedLenFeature([], tf.int64),\n", | |
"            'image/encoded': tf.FixedLenFeature([], tf.string),\n", | |
"            'boxes/n_boxes': tf.FixedLenFeature([], tf.int64),\n", | |
"            'boxes/encoded': tf.FixedLenFeature([], tf.string)\n", | |
"        })\n", | |
"    \n", | |
"    # Rebuild the image from its raw bytes using the stored dimensions.\n", | |
"    height = tf.cast(features['image/height'], tf.int32)\n", | |
"    width = tf.cast(features['image/width'], tf.int32)\n", | |
"    channels = tf.cast(features['image/channels'], tf.int32)\n", | |
"    image = tf.decode_raw(features['image/encoded'], tf.uint8)\n", | |
"    tf_image_shape = tf.stack([height, width, channels])\n", | |
"    image = tf.reshape(image, tf_image_shape)\n", | |
"    \n", | |
"    # Use the stored box count for the dynamic reshape instead of a\n", | |
"    # hard-coded 20 (n_boxes was previously parsed but never used).\n", | |
"    n_boxes = tf.cast(features['boxes/n_boxes'], tf.int32)\n", | |
"    boxes = tf.decode_raw(features['boxes/encoded'], tf.float32)\n", | |
"    boxes_shape = tf.stack([n_boxes, 4])\n", | |
"    boxes = tf.reshape(boxes, boxes_shape)\n", | |
"\n", | |
"    # Static shapes need to be known for batching!\n", | |
"    boxes.set_shape([20, 4])\n", | |
"    image.set_shape([100, 100, 3])\n", | |
"\n", | |
"    return image, boxes" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### 6. Use a queue runner as before, with TFRecords this time" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"queue = tf.train.string_input_producer(['dummy_example.tfrecords'])\n", | |
"reader = tf.TFRecordReader()\n", | |
"_, serialized_example = reader.read(queue)\n", | |
"image_tensor, box_tensor = parse_serialized(serialized_example)\n", | |
"\n", | |
"image_tensor, box_tensor = tf.train.shuffle_batch(\n", | |
" [image_tensor, box_tensor], batch_size=1, capacity=30, num_threads=1, min_after_dequeue=10\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [], | |
"source": [ | |
"with tf.Session() as sess:\n", | |
" \n", | |
" coord = tf.train.Coordinator()\n", | |
" threads = tf.train.start_queue_runners(coord=coord)\n", | |
"\n", | |
" i = image_tensor.eval()\n", | |
" d = box_tensor.eval()\n", | |
" \n", | |
" coord.request_stop()\n", | |
" coord.join(threads)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"np.testing.assert_array_equal(i[0], dummy_image)\n", | |
"np.testing.assert_array_equal(d[0], dummy_annotations)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### 6b. Create a TFRecordDataset" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"records_path = 'dummy_example.tfrecords'\n", | |
"dataset = tf.data.TFRecordDataset(records_path)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### Repeat a desired number of epochs" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"dataset = dataset.repeat(10)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### Shuffle the dataset" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"dataset = dataset.shuffle(buffer_size=1000)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### Decode " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"dataset = dataset.map(\n", | |
" lambda rec: parse_serialized(rec),\n", | |
" num_parallel_calls=2\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### Batch" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"dataset = dataset.batch(1)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#### Prefetch and make iterator" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"dataset = dataset.prefetch(1)\n", | |
"\n", | |
"iterator = dataset.make_one_shot_iterator()\n", | |
"record = iterator.get_next()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"record" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [], | |
"source": [ | |
"with tf.Session() as sess:\n", | |
" rc = sess.run(record)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"np.testing.assert_array_equal(rc[0][0], dummy_image)\n", | |
"np.testing.assert_array_equal(rc[1][0], dummy_annotations)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### 7. Wrap everything into an input function\n", | |
"\n", | |
"For use within e.g. Tensorflow's `Estimator` framework" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def input_fn(records_path, n_epochs=1, batch_size=1, shuffle=True):\n", | |
"    # Build a tf.data input pipeline over a TFRecords file, for use with\n", | |
"    # e.g. TensorFlow's Estimator framework. Returns the (image, boxes)\n", | |
"    # tensors produced by a one-shot iterator over the dataset.\n", | |
"    dataset = tf.data.TFRecordDataset(records_path)\n", | |
"    dataset = dataset.repeat(n_epochs)\n", | |
"    \n", | |
"    if shuffle:\n", | |
"        # Shuffle the raw serialized records before the (parallel) parse step.\n", | |
"        dataset = dataset.shuffle(buffer_size=1000)\n", | |
"    \n", | |
"    dataset = dataset.map(\n", | |
"        lambda rec: parse_serialized(rec),\n", | |
"        num_parallel_calls=2\n", | |
"    )\n", | |
"    \n", | |
"    dataset = dataset.batch(batch_size)\n", | |
"    # Overlap preparation of the next batch with consumption of the current one.\n", | |
"    dataset = dataset.prefetch(1)\n", | |
"\n", | |
"    iterator = dataset.make_one_shot_iterator()\n", | |
"    record = iterator.get_next()\n", | |
"    \n", | |
"    return record" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"record = input_fn('dummy_example.tfrecords')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"with tf.Session() as sess:\n", | |
" rc = sess.run(record)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"np.testing.assert_array_equal(rc[0][0], dummy_image)\n", | |
"np.testing.assert_array_equal(rc[1][0], dummy_annotations)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment