Skip to content

Instantly share code, notes, and snippets.

@shug3502
Forked from aoikonomop/tfrecord_example.ipynb
Created March 21, 2018 17:14
Show Gist options
  • Save shug3502/36e33392559c0a3b54534a898f5177e4 to your computer and use it in GitHub Desktop.
Save shug3502/36e33392559c0a3b54534a898f5177e4 to your computer and use it in GitHub Desktop.
TFRecords Example
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def tf_bytes_feature(x):\n",
"    \"\"\"Wrap a single bytes value in a tf.train.Feature.\"\"\"\n",
"    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[x]))\n",
"\n",
"\n",
"def tf_int_feature(x):\n",
"    \"\"\"Wrap a single int value in a tf.train.Feature.\"\"\"\n",
"    return tf.train.Feature(int64_list=tf.train.Int64List(value=[x]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1. Create some dummy data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Dummy image: 100x100 RGB, uint8 in [0, 255).\n",
"dummy_image = (np.random.rand(100, 100, 3) * 255).astype(np.uint8)\n",
"\n",
"# Dummy annotations: 20 boxes with 4 float coordinates each.\n",
"dummy_annotations = (np.random.rand(20, 4) * 255).astype(np.float32)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2. Serialize into strings"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# ndarray.tostring() is a deprecated alias of tobytes(); same bytes out.\n",
"raw_image = dummy_image.tobytes()\n",
"raw_data = dummy_annotations.tobytes()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3. Stuff the data into an example protobuffer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Keys mirror the feature spec used by parse_serialized in step 5.\n",
"feature_map = {\n",
"    'image/height': tf_int_feature(dummy_image.shape[0]),\n",
"    'image/width': tf_int_feature(dummy_image.shape[1]),\n",
"    'image/channels': tf_int_feature(dummy_image.shape[2]),\n",
"    'image/encoded': tf_bytes_feature(raw_image),\n",
"    'boxes/n_boxes': tf_int_feature(dummy_annotations.shape[0]),\n",
"    'boxes/encoded': tf_bytes_feature(raw_data),\n",
"}\n",
"example = tf.train.Example(features=tf.train.Features(feature=feature_map))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4. Write it into a TFRecords file"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Context manager guarantees the writer is flushed and closed even if\n",
"# write() raises.\n",
"with tf.python_io.TFRecordWriter('dummy_example.tfrecords') as writer:\n",
"    writer.write(example.SerializeToString())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 5. Read it back"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def parse_serialized(serialized_example):\n",
"    \"\"\"Parse one serialized tf.train.Example into (image, boxes) tensors.\n",
"\n",
"    Mirrors the feature spec written in step 3. Static shapes are set at\n",
"    the end because batching requires fully-known shapes.\n",
"    \"\"\"\n",
"    features = tf.parse_single_example(\n",
"        serialized_example,\n",
"        features={\n",
"            'image/height': tf.FixedLenFeature([], tf.int64),\n",
"            'image/width': tf.FixedLenFeature([], tf.int64),\n",
"            'image/channels': tf.FixedLenFeature([], tf.int64),\n",
"            'image/encoded': tf.FixedLenFeature([], tf.string),\n",
"            'boxes/n_boxes': tf.FixedLenFeature([], tf.int64),\n",
"            'boxes/encoded': tf.FixedLenFeature([], tf.string)\n",
"        })\n",
"\n",
"    height = tf.cast(features['image/height'], tf.int32)\n",
"    width = tf.cast(features['image/width'], tf.int32)\n",
"    channels = tf.cast(features['image/channels'], tf.int32)\n",
"    image = tf.decode_raw(features['image/encoded'], tf.uint8)\n",
"    image = tf.reshape(image, tf.stack([height, width, channels]))\n",
"\n",
"    n_boxes = tf.cast(features['boxes/n_boxes'], tf.int32)\n",
"    boxes = tf.decode_raw(features['boxes/encoded'], tf.float32)\n",
"    # Use the stored box count instead of a hard-coded 20 (n_boxes was\n",
"    # previously decoded but never used).\n",
"    boxes = tf.reshape(boxes, tf.stack([n_boxes, 4]))\n",
"\n",
"    # Shapes need to be known for batching!\n",
"    boxes.set_shape([20, 4])\n",
"    image.set_shape([100, 100, 3])\n",
"\n",
"    return image, boxes"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 6. Use a queue runner as before, with TFRecords this time"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"queue = tf.train.string_input_producer(['dummy_example.tfrecords'])\n",
"reader = tf.TFRecordReader()\n",
"_, serialized_example = reader.read(queue)\n",
"image_tensor, box_tensor = parse_serialized(serialized_example)\n",
"\n",
"# Shuffle-and-batch the parsed tensors.\n",
"image_tensor, box_tensor = tf.train.shuffle_batch(\n",
"    [image_tensor, box_tensor],\n",
"    batch_size=1,\n",
"    capacity=30,\n",
"    num_threads=1,\n",
"    min_after_dequeue=10,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"with tf.Session() as sess:\n",
"\n",
"    coord = tf.train.Coordinator()\n",
"    threads = tf.train.start_queue_runners(coord=coord)\n",
"\n",
"    # Fetch both tensors in ONE run call: two separate .eval() calls each\n",
"    # dequeue a fresh batch, so image and boxes could come from different\n",
"    # records (it only worked before because every record is identical).\n",
"    i, d = sess.run([image_tensor, box_tensor])\n",
"\n",
"    coord.request_stop()\n",
"    coord.join(threads)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"np.testing.assert_array_equal(i[0], dummy_image)\n",
"np.testing.assert_array_equal(d[0], dummy_annotations)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 7. Create a TFRecordDataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"records_path = 'dummy_example.tfrecords'\n",
"dataset = tf.data.TFRecordDataset(records_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Repeat a desired number of epochs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset = dataset.repeat(10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Shuffle the dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset = dataset.shuffle(buffer_size=1000)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Decode "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The lambda wrapper is redundant; pass the parse function directly.\n",
"dataset = dataset.map(parse_serialized, num_parallel_calls=2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Batch"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset = dataset.batch(1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Prefetch and make iterator"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset = dataset.prefetch(1)\n",
"\n",
"iterator = dataset.make_one_shot_iterator()\n",
"record = iterator.get_next()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"record"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"with tf.Session() as sess:\n",
" rc = sess.run(record)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"np.testing.assert_array_equal(rc[0][0], dummy_image)\n",
"np.testing.assert_array_equal(rc[1][0], dummy_annotations)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 8. Wrap everything into an input function\n",
"\n",
"For use within e.g. Tensorflow's `Estimator` framework"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def input_fn(records_path, n_epochs=1, batch_size=1, shuffle=True,\n",
"             num_parallel_calls=2):\n",
"    \"\"\"Build a one-shot input pipeline over a TFRecords file.\n",
"\n",
"    Args:\n",
"        records_path: path to a .tfrecords file.\n",
"        n_epochs: number of passes over the data.\n",
"        batch_size: records per batch.\n",
"        shuffle: whether to shuffle records (buffer of 1000).\n",
"        num_parallel_calls: parallelism of the parse map (default keeps\n",
"            the previous hard-coded value of 2).\n",
"\n",
"    Returns:\n",
"        The iterator's (image, boxes) tensor tuple.\n",
"    \"\"\"\n",
"    dataset = tf.data.TFRecordDataset(records_path)\n",
"    dataset = dataset.repeat(n_epochs)\n",
"\n",
"    if shuffle:\n",
"        dataset = dataset.shuffle(buffer_size=1000)\n",
"\n",
"    # Pass the parse function directly; the lambda wrapper was redundant.\n",
"    dataset = dataset.map(parse_serialized,\n",
"                          num_parallel_calls=num_parallel_calls)\n",
"\n",
"    dataset = dataset.batch(batch_size)\n",
"    dataset = dataset.prefetch(1)\n",
"\n",
"    iterator = dataset.make_one_shot_iterator()\n",
"    return iterator.get_next()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"record = input_fn('dummy_example.tfrecords')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with tf.Session() as sess:\n",
" rc = sess.run(record)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"np.testing.assert_array_equal(rc[0][0], dummy_image)\n",
"np.testing.assert_array_equal(rc[1][0], dummy_annotations)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment