Skip to content

Instantly share code, notes, and snippets.

@sayakpaul
Last active August 25, 2022 05:10
Show Gist options
  • Save sayakpaul/f7d5cc312cd01cb31098fad3fd9c6b59 to your computer and use it in GitHub Desktop.
Save sayakpaul/f7d5cc312cd01cb31098fad3fd9c6b59 to your computer and use it in GitHub Desktop.
Demonstrates the compatibility between Feather and TensorFlow with a real dataset.
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "9bbe93c8",
"metadata": {},
"source": [
"## Flowers dataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "6827b5db",
"metadata": {},
"outputs": [],
"source": [
"# Skip the download/extract when the dataset directory already exists, so\n",
"# re-running the notebook does not fetch and unpack the archive again.\n",
"!test -d flower_photos || (wget -q https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz -O flower_photos.tgz && tar xf flower_photos.tgz)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "d411bfdc",
"metadata": {},
"outputs": [],
"source": [
"import glob\n",
"import os\n",
"import random\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import tensorflow as tf\n",
"from tqdm.auto import trange"
]
},
{
"cell_type": "markdown",
"id": "1fb3f64b",
"metadata": {},
"source": [
"## Collect image paths and derive labels"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "291ee5c0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(['flower_photos/roses/16209331331_343c899d38.jpg',\n",
" 'flower_photos/roses/5777669976_a205f61e5b.jpg',\n",
" 'flower_photos/roses/4860145119_b1c3cbaa4e_n.jpg',\n",
" 'flower_photos/roses/15011625580_7974c44bce.jpg',\n",
" 'flower_photos/roses/17953368844_be3d18cf30_m.jpg'],\n",
" 3670)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# glob returns paths in arbitrary (filesystem-dependent) order: sort first, then\n",
"# shuffle with a fixed seed so shards mix classes and are reproducible across runs.\n",
"image_paths = sorted(glob.glob(\"flower_photos/*/*.jpg\"))\n",
"random.Random(42).shuffle(image_paths)\n",
"image_paths[:5], len(image_paths)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8a69e7a8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['roses', 'roses', 'roses', 'roses', 'roses']"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# The class name is the image's parent directory. os.path handles the platform's\n",
"# path separator instead of assuming \"/\" and a fixed path depth.\n",
"all_labels = [os.path.basename(os.path.dirname(path)) for path in image_paths]\n",
"all_labels[:5]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "055f907b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sort the class names so the name -> integer id mapping is deterministic.\n",
"unique_labels = sorted(set(all_labels))\n",
"label2_id = dict(zip(unique_labels, range(len(unique_labels))))\n",
"label2_id"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "b74fcedc",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[2, 2, 2, 2, 2]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Index directly so an unexpected class name raises KeyError instead of\n",
"# silently becoming None (which `dict.get` would return).\n",
"all_integer_labels = [label2_id[label] for label in all_labels]\n",
"all_integer_labels[:5]"
]
},
{
"cell_type": "markdown",
"id": "0673d274",
"metadata": {},
"source": [
"## Write into shards of Feather files"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "c08b9025",
"metadata": {},
"outputs": [],
"source": [
"from pyarrow.feather import write_feather\n",
"import pyarrow as pa\n",
"import tqdm\n",
"import math"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "51c1f0f6",
"metadata": {},
"outputs": [],
"source": [
"# batch_size: images written per Feather shard file.\n",
"# chunk_size: rows per record batch inside each shard (passed as `chunksize`).\n",
"batch_size, chunk_size = 1000, 1000"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "266f50a9",
"metadata": {},
"outputs": [],
"source": [
"def read_image(path, size=(64, 64)):\n",
"    \"\"\"Load an image file and return it as a resized uint8 RGB NumPy array.\n",
"\n",
"    Args:\n",
"        path: Path to an image file (JPEG, PNG, GIF or BMP).\n",
"        size: Target (height, width); defaults to the 64x64 used for the shards.\n",
"\n",
"    Returns:\n",
"        A NumPy array of shape (height, width, 3) and dtype uint8.\n",
"    \"\"\"\n",
"    image = tf.io.read_file(path)\n",
"    # The flower photos are JPEGs, so use the format-sniffing decoder instead of\n",
"    # decode_png; expand_animations=False guarantees a 3-D tensor for resize.\n",
"    image = tf.io.decode_image(image, channels=3, expand_animations=False)\n",
"    image = tf.image.resize(image, size)\n",
"    return tf.cast(image, tf.uint8).numpy()"
]
},
{
"cell_type": "markdown",
"id": "9d72c3eb",
"metadata": {},
"source": [
"There is a potential caveat: we have to resize all images to a uniform resolution before serializing (the `tf.data` readers below also declare a fixed `output_shapes`). There is probably a way to avoid this, but I haven't figured it out yet. I kept the resizing resolution intentionally low. "
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "f7e8dd25",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/z_/d29z43w90kz6f4kbzv5c9m9r0000gn/T/ipykernel_73443/2664877068.py:9: TqdmDeprecationWarning: Please use `tqdm.notebook.trange` instead of `tqdm.tnrange`\n",
" for step in tqdm.tnrange(int(math.ceil(len(image_paths) / batch_size))):\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8622d78ab209454bac3d34ceabfb170a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total images written: 1000.\n",
"Total images written: 2000.\n",
"Total images written: 3000.\n",
"Total images written: 3670.\n"
]
}
],
"source": [
"# Taken from (with some mods):\n",
"# https://towardsdatascience.com/data-formats-for-training-in-tensorflow-parquet-petastorm-feather-and-more-e55179eeeb72\n",
"# Can be done in a distributed manner.\n",
"\n",
"num_shards = math.ceil(len(image_paths) / batch_size)\n",
"total_images_written = 0\n",
"\n",
"# tqdm.auto picks the notebook widget when available (tqdm.tnrange is deprecated).\n",
"for step in trange(num_shards):\n",
"    start, stop = step * batch_size, (step + 1) * batch_size\n",
"    batch_image_paths = image_paths[start:stop]\n",
"    batch_image_labels = all_integer_labels[start:stop]\n",
"\n",
"    # One row per image, flattened to 1-D so every row has the same fixed length.\n",
"    data = [read_image(path).flatten() for path in batch_image_paths]\n",
"    table = pa.Table.from_arrays([data, batch_image_labels], [\"data\", \"labels\"])\n",
"    write_feather(table, f\"/tmp/flowers_feather_{step}.feather\", chunksize=chunk_size)\n",
"\n",
"    total_images_written += len(batch_image_paths)\n",
"    print(f\"Total images written: {total_images_written}.\")\n",
"\n",
"    del data"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "0f785479",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-rw-r--r-- 1 sayakpaul wheel 144M Jun 23 17:54 /tmp/df.feather\r\n",
"-rw-r--r-- 1 sayakpaul wheel 11M Jun 23 18:02 /tmp/flowers_feather_0.feather\r\n",
"-rw-r--r-- 1 sayakpaul wheel 11M Jun 23 18:02 /tmp/flowers_feather_1.feather\r\n",
"-rw-r--r-- 1 sayakpaul wheel 11M Jun 23 18:02 /tmp/flowers_feather_2.feather\r\n",
"-rw-r--r-- 1 sayakpaul wheel 7.7M Jun 23 18:02 /tmp/flowers_feather_3.feather\r\n"
]
}
],
"source": [
"# List only the flower shards, not unrelated Feather files that may sit in /tmp.\n",
"!ls -lh /tmp/flowers_feather_*.feather"
]
},
{
"cell_type": "markdown",
"id": "05a58126",
"metadata": {},
"source": [
"## Loading the Feather files into `tf.data`"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "a66a9b02",
"metadata": {},
"outputs": [],
"source": [
"import tensorflow_io.arrow as arrow_io"
]
},
{
"cell_type": "markdown",
"id": "85bbee91",
"metadata": {},
"source": [
"### From single Feather file"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "02acf656",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-06-23 18:06:29.670200: E tensorflow/core/framework/dataset.cc:580] UNIMPLEMENTED: Cannot compute input sources for dataset of type IO>ArrowFeatherDataset, because the dataset does not implement `InputDatasets`.\n",
"2022-06-23 18:06:29.670224: E tensorflow/core/framework/dataset.cc:584] UNIMPLEMENTED: Cannot merge options for dataset of type IO>ArrowFeatherDataset, because the dataset does not implement `InputDatasets`.\n"
]
}
],
"source": [
"# Read a single shard; each element is a batch of up to 32 rows of\n",
"# (flattened 64x64x3 uint8 pixels, int64 label).\n",
"dataset = arrow_io.ArrowFeatherDataset(\n",
"    [\"/tmp/flowers_feather_0.feather\"],\n",
"    batch_size=32,\n",
"    columns=(0, 1),  # column 0: \"data\", column 1: \"labels\"\n",
"    output_shapes=([64 * 64 * 3], []),\n",
"    output_types=(tf.uint8, tf.int64),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "8e99b0eb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(TensorSpec(shape=(None, 12288), dtype=tf.uint8, name=None),\n",
" TensorSpec(shape=(None,), dtype=tf.int64, name=None))"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Shapes are batched: (None, 12288) uint8 pixels and (None,) int64 labels.\n",
"dataset.element_spec"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "aa53b968",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-06-23 18:06:32.666834: E tensorflow/core/framework/dataset.cc:580] UNIMPLEMENTED: Cannot compute input sources for dataset of type IO>ArrowFeatherDataset, because the dataset does not implement `InputDatasets`.\n",
"2022-06-23 18:06:32.666861: E tensorflow/core/framework/dataset.cc:584] UNIMPLEMENTED: Cannot merge options for dataset of type IO>ArrowFeatherDataset, because the dataset does not implement `InputDatasets`.\n",
"2022-06-23 18:06:32.666985: E tensorflow/core/framework/dataset.cc:580] FAILED_PRECONDITION: Cannot compute input sources for dataset of type RootDataset, because sources could not be computed for input dataset of type IO>ArrowFeatherDataset\n"
]
}
],
"source": [
"# Grab the first batch and restore the image layout from the flattened rows.\n",
"for batch in dataset.take(1):\n",
"    image_batch = batch[0].numpy().reshape(-1, 64, 64, 3)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "40567270",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots()\n",
"ax.imshow(image_batch[19])  # already (64, 64, 3); squeeze() was a no-op\n",
"ax.set_title(\"Sample 19 of the first batch (shard 0)\")\n",
"ax.axis(\"off\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "20f99fe0",
"metadata": {},
"outputs": [],
"source": [
"# Delete a leftover Feather file from an earlier experiment; a broad\n",
"# `/tmp/*.feather` glob would otherwise also match it alongside the flower shards.\n",
"!rm -rf /tmp/df.feather"
]
},
{
"cell_type": "markdown",
"id": "51216243",
"metadata": {},
"source": [
"### All shards, interleaved in parallel into one batched dataset"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "8b95f20c",
"metadata": {},
"outputs": [],
"source": [
"# https://towardsdatascience.com/data-formats-for-training-in-tensorflow-parquet-petastorm-feather-and-more-e55179eeeb72\n",
"def get_dataset(file_pattern=\"/tmp/flowers_feather_*.feather\"):\n",
"    \"\"\"Build a tf.data pipeline that reads all Feather shards in parallel.\n",
"\n",
"    Args:\n",
"        file_pattern: Glob matching the shard files. Defaults to only the flower\n",
"            shards, so unrelated Feather files in /tmp are not picked up.\n",
"\n",
"    Returns:\n",
"        A prefetched dataset of (pixels, labels) batches of up to 32 rows, with\n",
"        shapes (None, 12288) uint8 and (None,) int64.\n",
"    \"\"\"\n",
"    autotune = tf.data.AUTOTUNE\n",
"    filenames = tf.data.Dataset.list_files(file_pattern, shuffle=True)\n",
"\n",
"    def make_ds(file):\n",
"        return arrow_io.ArrowFeatherDataset(\n",
"            [file],\n",
"            columns=[0, 1],\n",
"            output_types=(tf.uint8, tf.int64),\n",
"            output_shapes=([64 * 64 * 3], []),\n",
"            batch_size=32,\n",
"        )\n",
"\n",
"    # Read shards concurrently; batch order across shards is not deterministic.\n",
"    ds = filenames.interleave(make_ds, num_parallel_calls=autotune, deterministic=False)\n",
"    return ds.prefetch(autotune)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "04c69170",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(TensorSpec(shape=(None, 12288), dtype=tf.uint8, name=None),\n",
" TensorSpec(shape=(None,), dtype=tf.int64, name=None))"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Element spec matches the single-file dataset: batched (pixels, labels) tuples.\n",
"parallel_ds = get_dataset()\n",
"parallel_ds.element_spec"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "b9db5b95",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-06-23 18:18:15.240738: E tensorflow/core/framework/dataset.cc:580] UNIMPLEMENTED: Cannot compute input sources for dataset of type IO>ArrowFeatherDataset, because the dataset does not implement `InputDatasets`.\n",
"2022-06-23 18:18:15.240761: E tensorflow/core/framework/dataset.cc:584] UNIMPLEMENTED: Cannot merge options for dataset of type IO>ArrowFeatherDataset, because the dataset does not implement `InputDatasets`.\n",
"2022-06-23 18:18:15.240894: E tensorflow/core/framework/dataset.cc:580] UNIMPLEMENTED: Cannot compute input sources for dataset of type IO>ArrowFeatherDataset, because the dataset does not implement `InputDatasets`.\n",
"2022-06-23 18:18:15.240914: E tensorflow/core/framework/dataset.cc:584] UNIMPLEMENTED: Cannot merge options for dataset of type IO>ArrowFeatherDataset, because the dataset does not implement `InputDatasets`.\n",
"2022-06-23 18:18:15.241008: E tensorflow/core/framework/dataset.cc:580] UNIMPLEMENTED: Cannot compute input sources for dataset of type IO>ArrowFeatherDataset, because the dataset does not implement `InputDatasets`.\n",
"2022-06-23 18:18:15.241028: E tensorflow/core/framework/dataset.cc:584] UNIMPLEMENTED: Cannot merge options for dataset of type IO>ArrowFeatherDataset, because the dataset does not implement `InputDatasets`.\n",
"2022-06-23 18:18:15.241122: E tensorflow/core/framework/dataset.cc:580] UNIMPLEMENTED: Cannot compute input sources for dataset of type IO>ArrowFeatherDataset, because the dataset does not implement `InputDatasets`.\n",
"2022-06-23 18:18:15.241139: E tensorflow/core/framework/dataset.cc:584] UNIMPLEMENTED: Cannot merge options for dataset of type IO>ArrowFeatherDataset, because the dataset does not implement `InputDatasets`.\n"
]
}
],
"source": [
"# Pull one batch from the interleaved shards and unflatten it to images.\n",
"for batch in parallel_ds.take(1):\n",
"    image_batch = batch[0].numpy().reshape(-1, 64, 64, 3)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "0c5ab26e",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots()\n",
"ax.imshow(image_batch[19])  # already (64, 64, 3); squeeze() was a no-op\n",
"ax.set_title(\"Sample 19 of a batch from the interleaved shards\")\n",
"ax.axis(\"off\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "4d217a2d",
"metadata": {},
"source": [
"## Explicit batching + mapping a preprocessing fn"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "c4b19e20",
"metadata": {},
"outputs": [],
"source": [
"# Mod + some taken from\n",
"# https://towardsdatascience.com/data-formats-for-training-in-tensorflow-parquet-petastorm-feather-and-more-e55179eeeb72\n",
"def get_dataset_no_batch(file_pattern=\"/tmp/flowers_feather_*.feather\"):\n",
"    \"\"\"Read all shards, re-batch to 32 rows, and scale pixels to [0, 1].\n",
"\n",
"    Each shard reader uses batch_mode=\"auto\"; its output is unbatched and then\n",
"    re-batched to 32 rows, so the batch size no longer depends on how the shards\n",
"    were chunked.\n",
"\n",
"    Args:\n",
"        file_pattern: Glob matching the shard files. Defaults to only the flower\n",
"            shards, so unrelated Feather files in /tmp are not picked up.\n",
"\n",
"    Returns:\n",
"        A prefetched dataset of (pixels, labels) with shapes (None, 12288)\n",
"        float32 and (None,) int64.\n",
"    \"\"\"\n",
"    autotune = tf.data.AUTOTUNE\n",
"    filenames = tf.data.Dataset.list_files(file_pattern, shuffle=True)\n",
"\n",
"    def make_ds(file):\n",
"        return arrow_io.ArrowFeatherDataset(\n",
"            [file],\n",
"            columns=[0, 1],\n",
"            output_types=(tf.uint8, tf.int64),\n",
"            output_shapes=([64 * 64 * 3], []),\n",
"            batch_mode=\"auto\",\n",
"        )\n",
"\n",
"    ds = filenames.interleave(make_ds, num_parallel_calls=autotune, deterministic=False)\n",
"    ds = ds.unbatch().batch(32)\n",
"    # Map after batching so the cast/scale runs once per batch, vectorized.\n",
"    ds = ds.map(\n",
"        lambda x, y: (tf.cast(x, tf.float32) / 255, y), num_parallel_calls=autotune\n",
"    )\n",
"    return ds.prefetch(autotune)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "19472fe0",
"metadata": {},
"outputs": [],
"source": [
"# Build the pipeline that re-batches to 32 rows and scales pixels to [0, 1].\n",
"dataset_no_batch = get_dataset_no_batch()"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "3fc5c84b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(TensorSpec(shape=(None, 12288), dtype=tf.float32, name=None),\n",
" TensorSpec(shape=(None,), dtype=tf.int64, name=None))"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Pixels are float32 after the scaling map; labels stay int64.\n",
"dataset_no_batch.element_spec"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "53196855",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(32, 12288)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-06-23 18:19:45.585394: E tensorflow/core/framework/dataset.cc:580] UNIMPLEMENTED: Cannot compute input sources for dataset of type IO>ArrowFeatherDataset, because the dataset does not implement `InputDatasets`.\n",
"2022-06-23 18:19:45.585416: E tensorflow/core/framework/dataset.cc:584] UNIMPLEMENTED: Cannot merge options for dataset of type IO>ArrowFeatherDataset, because the dataset does not implement `InputDatasets`.\n",
"2022-06-23 18:19:45.585551: E tensorflow/core/framework/dataset.cc:580] UNIMPLEMENTED: Cannot compute input sources for dataset of type IO>ArrowFeatherDataset, because the dataset does not implement `InputDatasets`.\n",
"2022-06-23 18:19:45.585570: E tensorflow/core/framework/dataset.cc:584] UNIMPLEMENTED: Cannot merge options for dataset of type IO>ArrowFeatherDataset, because the dataset does not implement `InputDatasets`.\n",
"2022-06-23 18:19:45.585696: E tensorflow/core/framework/dataset.cc:580] UNIMPLEMENTED: Cannot compute input sources for dataset of type IO>ArrowFeatherDataset, because the dataset does not implement `InputDatasets`.\n",
"2022-06-23 18:19:45.585728: E tensorflow/core/framework/dataset.cc:584] UNIMPLEMENTED: Cannot merge options for dataset of type IO>ArrowFeatherDataset, because the dataset does not implement `InputDatasets`.\n",
"2022-06-23 18:19:45.585847: E tensorflow/core/framework/dataset.cc:580] UNIMPLEMENTED: Cannot compute input sources for dataset of type IO>ArrowFeatherDataset, because the dataset does not implement `InputDatasets`.\n",
"2022-06-23 18:19:45.585864: E tensorflow/core/framework/dataset.cc:584] UNIMPLEMENTED: Cannot merge options for dataset of type IO>ArrowFeatherDataset, because the dataset does not implement `InputDatasets`.\n"
]
}
],
"source": [
"# Peek at one batch: expect 32 rows of 12288 flattened pixel values.\n",
"for sample in dataset_no_batch.take(1):\n",
" print(sample[0].shape)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
tensorflow==2.9.1
tensorflow-io==0.26.0
pyarrow
tqdm
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment