@KaiLicht
Last active February 22, 2020 23:57
Image class for multiple image types
{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from typing import Union, List, Tuple, Callable\n",
"\n",
"from fastcore.all import patch, patch_to\n",
"from inspect import getmembers, ismethod\n",
"import PIL.Image\n",
"import numpy as np\n",
"import torch\n",
"from nbdev.showdoc import *"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Supported image classes (the classes, not the factory functions torch.tensor/np.array)\n",
"ImgTypes = Union[torch.Tensor, np.ndarray, PIL.Image.Image]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Base class"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"# Registry of supported image types: (name used in `_disp_<fun>_<name>`, class)\n",
"multiImageTypes = [(\"numpy\", np.ndarray),\n",
"                   (\"torch\", torch.Tensor),\n",
"                   (\"PIL\", PIL.Image.Image)]\n",
"\n",
"class MultiImage():\n",
"    \"Base class for multi images.\"\n",
"    def __init__(self, img:ImgTypes, ann:dict=None):\n",
"        self.img = img\n",
"        self.allTypes = multiImageTypes\n",
"        # Remember the name of the matching image type, e.g. \"torch\"\n",
"        for t in self.allTypes:\n",
"            if isinstance(img, t[1]): self.type = t[0]\n",
"        if not hasattr(self, 'type'):\n",
"            raise TypeError(f\"Image type {type(img)} is not supported!\")\n",
"        # (generic name, type name) pairs, e.g. (\"myfun\", \"torch\")\n",
"        self._dispFuns = self._get_dispatch_funs()\n",
"        self._genericFuns = list(set([f[0] for f in self._dispFuns]))\n",
"        # Bind each generic name to the implementation for this image type,\n",
"        # or to a stub that raises NotImplementedError if there is none\n",
"        for f in self._genericFuns:\n",
"            if (f, self.type) in self._dispFuns:\n",
"                dispatchFun = getattr(self, f\"_disp_{f}_{self.type}\")\n",
"                setattr(self, f, dispatchFun)\n",
"            else:\n",
"                setattr(self, f, self._generic_not_implemented)\n",
"\n",
"    def _generic_not_implemented(self, *args, **kwargs):\n",
"        raise NotImplementedError(f\"Method not implemented for {self.type} image\")\n",
"\n",
"    def _get_dispatch_funs(self):\n",
"        \"Gets the functions to dispatch.\"\n",
"        # `_disp_<fun>_<type>` -> (\"<fun>\", \"<type>\")\n",
"        funs = [f[0] for f in getmembers(self, predicate=ismethod) if f[0].startswith('_disp_')]\n",
"        return [(\"_\".join(f.split(\"_\")[2:-1]), f.split(\"_\")[-1]) for f in funs]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"show_doc(MultiImage, title_level=3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This base class takes an image of type `Union[torch.Tensor, np.ndarray, PIL.Image.Image]` and stores it as an instance attribute, together with the name of its image type as a string. The type-specific implementations of its methods are added in the following cells.\n",
"* To see the code of the base class, click the source button in the upper right.\n",
"* **Add a new method:** To add a method that works on all image types, implement one function per image type following the naming scheme `_disp_<fun>_<img_type>`, as in the template below. For a method `myfun`, the torch version is `_disp_myfun_torch`, which reads as: dispatch `myfun` to `torch`. The `_disp_` prefix exposes the function to the dispatcher, so no further registration is needed: when a `MultiImage` is created, `myfun` is bound to the implementation that matches its image type. If a method is implemented for only some image types, calling it on any other type raises a `NotImplementedError`.\n",
"* **Add a new image type:** Add a tuple `(\"strName\", type)` to `multiImageTypes` and implement the matching `_disp_<fun>_strName` functions. The type name must not contain underscores, because the dispatcher takes everything after the last `_` as the type name."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Add a method for all image types"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"@patch\n",
"def _disp_myfun_torch(self:MultiImage, myArg=1):\n",
"    return f\"{type(self.img)} | torch-function | {myArg}\"\n",
"\n",
"@patch\n",
"def _disp_myfun_numpy(self:MultiImage, myArg=2):\n",
"    return f\"{type(self.img)} | numpy-function | {myArg}\"\n",
"\n",
"@patch\n",
"def _disp_myfun_PIL(self:MultiImage, myArg=3):\n",
"    return f\"{type(self.img)} | PIL-function | {myArg}\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"show_doc(MultiImage._disp_myfun_torch, title_level=3, name=\"MultiImage.myfun\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Test it:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"imgTorch = torch.rand(128,128,3)\n",
"imgNumpy = np.random.rand(128,128,3)\n",
"# PIL expects uint8 values in [0, 255] for an RGB image\n",
"imgPIL = PIL.Image.fromarray((np.random.rand(128,128,3) * 255).astype(np.uint8))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'torch.Tensor'> | torch-function | 1\n",
"<class 'numpy.ndarray'> | numpy-function | 2\n",
"<class 'PIL.Image.Image'> | PIL-function | 3\n"
]
}
],
"source": [
"print(MultiImage(imgTorch).myfun())\n",
"print(MultiImage(imgNumpy).myfun())\n",
"print(MultiImage(imgPIL).myfun())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
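The `_disp_<fun>_<type>` dispatch described in the notebook can be tried without torch, numpy, PIL, fastcore or nbdev. The following is a minimal sketch, not the notebook's own code: it assumes Python 3.6+, uses `list`, `bytes` and `str` as stand-in "image types", and the methods `size` and `first` are hypothetical examples. It also shows adding a new type by registering it and attaching one more `_disp_` method to the class.

```python
from inspect import getmembers, ismethod

# Stand-in "image types" (assumption: the notebook registers numpy, torch and PIL
# classes here). Each entry is (name used in `_disp_<fun>_<name>`, class).
multi_image_types = [("list", list), ("bytes", bytes)]


class MultiImage:
    "Binds each generic method name to `_disp_<name>_<type>` for the wrapped object's type."

    def __init__(self, img):
        self.img = img
        # Use the name of the first registered type that matches.
        matches = [name for name, cls in multi_image_types if isinstance(img, cls)]
        if not matches:
            raise TypeError(f"Image type {type(img).__name__} is not supported!")
        self.type = matches[0]
        # `_disp_<fun>_<type>` -> ["<fun>", "<type>"]
        disp = [name[len("_disp_"):].rsplit("_", 1)
                for name, _ in getmembers(self, predicate=ismethod)
                if name.startswith("_disp_")]
        # Bind every generic name on this instance: the matching implementation,
        # or a stub that raises NotImplementedError.
        for fun in {f for f, _ in disp}:
            if [fun, self.type] in disp:
                setattr(self, fun, getattr(self, f"_disp_{fun}_{self.type}"))
            else:
                setattr(self, fun, self._not_implemented)

    def _not_implemented(self, *args, **kwargs):
        raise NotImplementedError(f"Method not implemented for {self.type} image")

    # Hypothetical example methods: `size` for both types, `first` only for list.
    def _disp_size_list(self):
        return len(self.img)

    def _disp_size_bytes(self):
        return len(self.img)

    def _disp_first_list(self):
        return self.img[0]


# Adding a new type: register it, then attach a `_disp_<fun>_<type>` method to the
# class (roughly what fastcore's @patch does in the notebook).
multi_image_types.append(("str", str))


def _disp_size_str(self):
    return len(self.img)


MultiImage._disp_size_str = _disp_size_str

print(MultiImage([1, 2, 3]).size())  # 3
print(MultiImage(b"abcd").size())    # 4
print(MultiImage("hello").size())    # 5
print(MultiImage([7, 8]).first())    # 7
```

In this sketch, `MultiImage(b"abcd").first()` raises `NotImplementedError` because `first` only has a `list` implementation, and `MultiImage(3.14)` raises `TypeError` because `float` is not registered. Binding happens once per instance in `__init__`, so later calls are plain attribute lookups with no per-call type check.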