@mjurkus
Created August 27, 2019 13:53
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"import numpy as np\n",
"import os\n",
"import re\n",
"import math\n",
"import tensorflow as tf\n",
"import copy\n",
"from typing import Tuple, Optional, Callable, List"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"from toai.image import ImageAugmentor, ImageDataset, ImageParser, ImageResizer"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"class ImageDatasetConfig:\n",
"\n",
" def __init__(\n",
" self,\n",
" img_dims: Tuple[int, int, int],\n",
" parallel_calls: int = 1,\n",
" prefetch: int = 1,\n",
" preprocess_pipeline: List[Callable] = [],\n",
" batch_size: int = 8,\n",
" shuffle: bool = False,\n",
" regexp: str = None\n",
" ):\n",
" self.parallel_calls = parallel_calls\n",
" self.img_dims = img_dims\n",
" self.prefetch = prefetch\n",
" self.preprocess_pipeline = preprocess_pipeline\n",
" self.batch_size = batch_size\n",
" self.shuffle = shuffle\n",
" self.regexp = regexp\n",
"\n",
" def copy(self, preprocess_pipeline: List[Callable], shuffle: bool = False) -> 'ImageDatasetConfig':\n",
" new = copy.deepcopy(self)\n",
" new.shuffle = shuffle\n",
" new.preprocess_pipeline = preprocess_pipeline\n",
" return new\n",
"\n",
"\n",
"class ImageDataset:\n",
" data: tf.data.Dataset\n",
" x: np.ndarray\n",
" y: np.ndarray\n",
" length: int\n",
" steps: int\n",
" classes: np.ndarray\n",
" n_classes: int\n",
"\n",
" def __init__(self, config: ImageDatasetConfig):\n",
" self.config = config\n",
"\n",
" def build_from_df(self) -> 'ImageDataset':\n",
" pass\n",
"\n",
" def build_from_path(self, path: Path, default_label: Optional[str] = None) -> 'ImageDataset':\n",
" paths = []\n",
" labels = []\n",
"\n",
" if self.config.regexp:\n",
" for value in os.listdir(str(path)):\n",
" match = re.match(self.config.regexp, value)\n",
" if match:\n",
" labels.append(match.group(1))\n",
" elif default_label:\n",
" labels.append(default_label)\n",
" else:\n",
" raise ValueError(f\"No match found and no default value provided for value: {value}\")\n",
"\n",
" paths.append(f\"{path}/{value}\")\n",
" else:\n",
" raise ValueError(\"Unexpected configuration\")\n",
"\n",
" return self.__build(np.asarray(paths), np.asarray(labels))\n",
"\n",
" def __build(self, x: np.ndarray, y: np.ndarray) -> 'ImageDataset':\n",
" self.x = x\n",
" self.y = y\n",
" self.length = len(y)\n",
" self.classes = np.unique(y)\n",
" self.n_classes = len(self.classes)\n",
" self.steps = math.ceil(self.length / self.config.batch_size)\n",
"\n",
" image_ds = tf.data.Dataset.from_tensor_slices(x)\n",
"\n",
" for fun in self.config.preprocess_pipeline:\n",
" image_ds = image_ds.map(fun, num_parallel_calls=self.config.parallel_calls)\n",
"\n",
" label_ds = tf.data.Dataset.from_tensor_slices(y.astype(float))\n",
" dataset = tf.data.Dataset.zip((image_ds, label_ds))\n",
"\n",
" if self.config.shuffle:\n",
" dataset = dataset.shuffle(self.config.batch_size)\n",
"\n",
" self.data = dataset.batch(self.config.batch_size).repeat().prefetch(self.config.prefetch)\n",
"\n",
" return self"
]
},
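{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick illustration of the `regexp`-based labeling used by `build_from_path`: the first capture group of the pattern becomes the label. The file names below are hypothetical examples, not files from the actual dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical file names; as in build_from_path, the first capture group is the label.\n",
"example_regexp = r\"(\\d+)\"\n",
"for name in [\"3_steak_001.jpg\", \"12_pizza_042.jpg\", \"no_digits.jpg\"]:\n",
"    match = re.match(example_regexp, name)\n",
"    print(name, \"->\", match.group(1) if match else \"no match (would need default_label)\")"
]
},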
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"DATA_DIR = Path(\"food\")\n",
"IMG_DIMS = (224, 224, 3)\n",
"\n",
"base_config = ImageDatasetConfig(\n",
" img_dims=IMG_DIMS,\n",
" parallel_calls=1,\n",
" batch_size=8,\n",
" prefetch=1,\n",
" regexp=r\"(\\d+)\"\n",
")\n",
"\n",
"train = ImageDataset(\n",
" base_config\n",
" .copy(\n",
" preprocess_pipeline = [\n",
" ImageParser(),\n",
" ImageResizer(img_dims=IMG_DIMS, resize=\"random_crop\", crop_adjustment=1.6),\n",
" ImageAugmentor(level=3, flips=\"both\"),\n",
" ],\n",
" shuffle = True\n",
" )\n",
").build_from_path(DATA_DIR/'training')\n",
"\n",
"validation = ImageDataset(\n",
" base_config\n",
" .copy(\n",
" preprocess_pipeline = [\n",
" ImageParser(),\n",
" ImageResizer(img_dims=IMG_DIMS, resize=\"crop\", crop_adjustment=1.0),\n",
" ]\n",
" )\n",
").build_from_path(DATA_DIR/'validation')"
]
}
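,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of how the built datasets could feed a model, assuming a placeholder `tf.keras` classifier and that the digits captured by the regexp are usable as integer class indices; the architecture and training settings below are illustrative only, while `data`, `steps`, and `n_classes` come from `ImageDataset` above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch only: the model is a placeholder, not part of the original gist.\n",
"model = tf.keras.Sequential([\n",
"    tf.keras.layers.Conv2D(32, 3, activation=\"relu\", input_shape=IMG_DIMS),\n",
"    tf.keras.layers.GlobalAveragePooling2D(),\n",
"    tf.keras.layers.Dense(train.n_classes, activation=\"softmax\"),\n",
"])\n",
"model.compile(optimizer=\"adam\", loss=\"sparse_categorical_crossentropy\", metrics=[\"accuracy\"])\n",
"\n",
"# The datasets repeat indefinitely, so steps_per_epoch / validation_steps are required.\n",
"model.fit(\n",
"    train.data,\n",
"    steps_per_epoch=train.steps,\n",
"    validation_data=validation.data,\n",
"    validation_steps=validation.steps,\n",
"    epochs=1,\n",
")"
]
}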
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}