Created
August 27, 2019 13:53
-
-
Save mjurkus/2677bbf9461d963411217f1be223e553 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from pathlib import Path\n", | |
"import numpy as np\n", | |
"import os\n", | |
"import re\n", | |
"import math\n", | |
"import tensorflow as tf\n", | |
"import copy\n", | |
"from typing import Tuple, Optional, Callable, List" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from toai.image import ImageAugmentor, ImageDataset, ImageParser, ImageResizer" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 36, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
class ImageDatasetConfig:
    """Bundle of settings used to build an image tf.data pipeline.

    Holds the preprocessing callables plus batching, prefetching and
    shuffling knobs; `ImageDataset` reads these when materialising the
    dataset.
    """

    def __init__(
        self,
        img_dims: Tuple[int, int, int],
        parallel_calls: int = 1,
        prefetch: int = 1,
        preprocess_pipeline: Optional[List[Callable]] = None,
        batch_size: int = 8,
        shuffle: bool = False,
        regexp: Optional[str] = None,
    ):
        # BUGFIX: the original used a mutable default (`[]`) for
        # `preprocess_pipeline`, so every instance created without an
        # explicit pipeline shared (and could corrupt) the same list.
        # `None` sentinel + per-instance list is the safe equivalent.
        self.parallel_calls = parallel_calls
        self.img_dims = img_dims
        self.prefetch = prefetch
        self.preprocess_pipeline = [] if preprocess_pipeline is None else preprocess_pipeline
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.regexp = regexp

    def copy(
        self, preprocess_pipeline: List[Callable], shuffle: bool = False
    ) -> "ImageDatasetConfig":
        """Return a deep copy of this config with a new pipeline and shuffle flag.

        The copy is deep, so mutating the clone never leaks back into the
        base configuration it was derived from.
        """
        clone = copy.deepcopy(self)
        clone.shuffle = shuffle
        clone.preprocess_pipeline = preprocess_pipeline
        return clone
"\n", | |
"\n", | |
class ImageDataset:
    """Builds a batched, optionally shuffled tf.data pipeline of (image, label) pairs.

    Construct with an `ImageDatasetConfig`, then call one of the
    `build_from_*` methods; the built tf.data pipeline is exposed as
    `self.data` and the fitted metadata (`length`, `steps`, `classes`, ...)
    as plain attributes.
    """

    # Populated by __build. Annotations are written as strings so the class
    # can be defined even when evaluated before heavy imports are resolved;
    # behaviour is otherwise unchanged.
    data: "tf.data.Dataset"
    x: "np.ndarray"
    y: "np.ndarray"
    length: int
    steps: int
    classes: "np.ndarray"
    n_classes: int

    def __init__(self, config: "ImageDatasetConfig"):
        self.config = config

    def build_from_df(self) -> "ImageDataset":
        # BUGFIX: was a bare `pass`, which silently returned None despite
        # the declared 'ImageDataset' return type. Fail loudly until the
        # DataFrame path is actually implemented.
        raise NotImplementedError("build_from_df is not implemented yet")

    def build_from_path(self, path: Path, default_label: Optional[str] = None) -> "ImageDataset":
        """Collect file paths and labels from a directory, then build the pipeline.

        Labels are taken from group 1 of `config.regexp` applied to each
        filename; `default_label` is used when the regexp does not match.

        Raises:
            ValueError: if no regexp is configured, or a filename does not
                match it and no `default_label` was supplied.
        """
        paths = []
        labels = []

        if self.config.regexp:
            for value in os.listdir(str(path)):
                match = re.match(self.config.regexp, value)
                if match:
                    labels.append(match.group(1))
                elif default_label:
                    labels.append(default_label)
                else:
                    raise ValueError(f"No match found and no default value provided for value: {value}")

                paths.append(f"{path}/{value}")
        else:
            raise ValueError("Unexpected configuration")

        return self.__build(np.asarray(paths), np.asarray(labels))

    def __build(self, x: "np.ndarray", y: "np.ndarray") -> "ImageDataset":
        """Materialise the tf.data pipeline from path array `x` and label array `y`."""
        self.x = x
        self.y = y
        self.length = len(y)
        self.classes = np.unique(y)
        self.n_classes = len(self.classes)
        # ceil so a final partial batch still counts as a step
        self.steps = math.ceil(self.length / self.config.batch_size)

        image_ds = tf.data.Dataset.from_tensor_slices(x)

        # Each pipeline stage maps an element (path -> ... -> image tensor).
        for fun in self.config.preprocess_pipeline:
            image_ds = image_ds.map(fun, num_parallel_calls=self.config.parallel_calls)

        # NOTE(review): labels must be numeric strings for astype(float) to
        # succeed — regexp group 1 is assumed to capture digits; confirm.
        label_ds = tf.data.Dataset.from_tensor_slices(y.astype(float))
        dataset = tf.data.Dataset.zip((image_ds, label_ds))

        if self.config.shuffle:
            # NOTE(review): a shuffle buffer of just `batch_size` gives very
            # weak shuffling (elements only move within a tiny window).
            # Consider a larger buffer, or shuffling the path array before
            # the decode map, where elements are cheap strings. Left as-is
            # to preserve behaviour.
            dataset = dataset.shuffle(self.config.batch_size)

        self.data = dataset.batch(self.config.batch_size).repeat().prefetch(self.config.prefetch)

        return self
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 37, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Build the training and validation image datasets from the `food` directory.
DATA_DIR = Path("food")
IMG_DIMS = (224, 224, 3)

# Shared base settings; per-split pipelines are grafted on via .copy().
base_config = ImageDatasetConfig(
    img_dims=IMG_DIMS,
    parallel_calls=1,
    batch_size=8,
    prefetch=1,
    regexp=r"(\d+)",
)

# Training split: random crop + augmentation, shuffled.
train_pipeline = [
    ImageParser(),
    ImageResizer(img_dims=IMG_DIMS, resize="random_crop", crop_adjustment=1.6),
    ImageAugmentor(level=3, flips="both"),
]
train_config = base_config.copy(preprocess_pipeline=train_pipeline, shuffle=True)
train = ImageDataset(train_config).build_from_path(DATA_DIR / "training")

# Validation split: deterministic center crop, no augmentation, no shuffle.
validation_pipeline = [
    ImageParser(),
    ImageResizer(img_dims=IMG_DIMS, resize="crop", crop_adjustment=1.0),
]
validation_config = base_config.copy(preprocess_pipeline=validation_pipeline)
validation = ImageDataset(validation_config).build_from_path(DATA_DIR / "validation")
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.3" | |
}, | |
"toc": { | |
"base_numbering": 1, | |
"nav_menu": {}, | |
"number_sections": true, | |
"sideBar": true, | |
"skip_h1_title": false, | |
"title_cell": "Table of Contents", | |
"title_sidebar": "Contents", | |
"toc_cell": false, | |
"toc_position": {}, | |
"toc_section_display": true, | |
"toc_window_display": false | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment