Skip to content

Instantly share code, notes, and snippets.

@MInner
Last active February 17, 2018 23:50
Show Gist options
  • Save MInner/ea549d7eb6b53cd8467d8d3b3a5d0d0b to your computer and use it in GitHub Desktop.
Save MInner/ea549d7eb6b53cd8467d8d3b3a5d0d0b to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
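"## API showcase\n",
"\n",
"Run the **Definition** section below first: the examples in this cell use `NDArr`, which is defined there."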
]
},
{
"cell_type": "code",
"execution_count": 112,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0.88690007 0.55995198 0.38361163 0.68403184 0.0185156 0.44171983\n",
" 0.31691619 0.02937568 0.00836834 0.78880494]\n",
"err\n",
"[ 0.7695575 0.32766094 0.23568485 0.20659444 0.71498636 0.24506314\n",
" 0.304722 0.57193584 0.61933445 0.24320615]\n",
"err\n",
"[0 0 0 0 0 0 0 0 0 0]\n",
"err\n",
"[[ 0.82250916 0.40552113 0.8301193 0.53392686 0.95118346 0.23415553\n",
" 0.5256737 0.24420565 0.82948549 0.86105571 0.77893775 0.46440709\n",
" 0.00869608 0.06291643 0.71569858 0.40135511 0.55258019 0.08750497\n",
" 0.1939549 0.40050567]\n",
" [ 0.08454198 0.6199718 0.79065577 0.79124625 0.87329852 0.89822278\n",
" 0.34901504 0.47744634 0.98617159 0.86818536 0.298728 0.94149995\n",
" 0.67553114 0.5602536 0.72955612 0.89932782 0.42081279 0.40602921\n",
" 0.89699766 0.83561404]\n",
" [ 0.59831343 0.35428523 0.78468076 0.18829544 0.61758936 0.14843397\n",
" 0.00596944 0.94308102 0.3325512 0.90827114 0.26069004 0.48808149\n",
" 0.63955319 0.5244585 0.57968004 0.46761899 0.47280379 0.74128998\n",
" 0.17967916 0.44656813]\n",
" [ 0.50523911 0.07422026 0.26768233 0.11226564 0.65410621 0.96435585\n",
" 0.05394845 0.14270239 0.5248568 0.78780012 0.90629328 0.27774761\n",
" 0.98978354 0.05009231 0.58057635 0.66617815 0.79489857 0.2342683\n",
" 0.83614349 0.01591979]\n",
" [ 0.04143335 0.61623312 0.64599089 0.41537789 0.23042656 0.43502103\n",
" 0.2977826 0.61752906 0.02659585 0.12120256 0.00599285 0.83223846\n",
" 0.18163289 0.22369256 0.23993126 0.28848517 0.56280879 0.04957049\n",
" 0.88140353 0.25244182]\n",
" [ 0.86919285 0.19760094 0.47385413 0.67153636 0.12651067 0.03326239\n",
" 0.18196367 0.01481561 0.40641939 0.66161099 0.33279351 0.15837677\n",
" 0.39526333 0.29394726 0.61702674 0.96257239 0.17626561 0.8929886\n",
" 0.87638873 0.06589202]\n",
" [ 0.67906735 0.50630077 0.69801478 0.13452838 0.62129216 0.20966711\n",
" 0.65118637 0.93092802 0.06673653 0.58469732 0.51138083 0.81524443\n",
" 0.61470865 0.23558816 0.55740143 0.39730801 0.44250091 0.9901409\n",
" 0.31236297 0.52510855]\n",
" [ 0.07341446 0.76972049 0.97572684 0.77730214 0.33045504 0.18508604\n",
" 0.22107951 0.38452896 0.7564924 0.67878656 0.45544665 0.94758416\n",
" 0.59257704 0.64189236 0.59865144 0.67881047 0.00933364 0.50467215\n",
" 0.63850554 0.3517403 ]\n",
" [ 0.73160472 0.75799142 0.67398197 0.37578636 0.88886733 0.63472282\n",
" 0.77206326 0.40438063 0.52069013 0.57340567 0.46108637 0.60409608\n",
" 0.67011748 0.09200092 0.15267226 0.35911962 0.34958308 0.13727258\n",
" 0.76444041 0.16011395]\n",
" [ 0.04931594 0.17037391 0.4409902 0.03624572 0.52629626 0.48542883\n",
" 0.91848478 0.64623007 0.64216531 0.48021692 0.96896661 0.94137667\n",
" 0.34870942 0.71793587 0.79579319 0.20487346 0.47547493 0.90524316\n",
" 0.15253653 0.7574594 ]]\n",
"err\n",
"[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
" [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
" [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
" [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
" [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
" [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
" [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
" [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
" [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
" [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]\n",
"err\n",
"[[[0 0 0 0]\n",
" [0 0 0 0]\n",
" [0 0 0 0]]\n",
"\n",
" [[0 0 0 0]\n",
" [0 0 0 0]\n",
" [0 0 0 0]]]\n",
"err\n"
]
}
],
"source": [
"import numpy as np\n",
"from typecheck import typecheck, InputParameterError\n",
"\n",
"## NDArr is defined in the Definition section below: run that cell first.\n",
"np.random.seed(0)  # make the printed example arrays reproducible\n",
"\n",
"def show(func, accepted, rejected):\n",
"    \"\"\"Print func(accepted), then show that func(rejected) fails the type check.\n",
"\n",
"    Only InputParameterError (a failed type check) is caught, so any other\n",
"    exception still surfaces instead of being reported as 'err'.\n",
"    \"\"\"\n",
"    print(func(accepted))  ## works!\n",
"    try:\n",
"        print(func(rejected))  ## fails!\n",
"    except InputParameterError:\n",
"        print('err')\n",
"\n",
"## you can specify the desired shape via [..] syntax\n",
"## or by passing a shape=(..) value to the constructor\n",
"@typecheck\n",
"def nice_func1(a: NDArr[10]):\n",
"    return a\n",
"\n",
"show(nice_func1, np.random.rand(10), np.random.rand(20))\n",
"\n",
"## does the same thing as above\n",
"@typecheck\n",
"def nice_func1_1(a: NDArr(shape=(10,))):\n",
"    return a\n",
"\n",
"show(nice_func1_1, np.random.rand(10), np.random.rand(20))\n",
"\n",
"## you can also specify dtype restrictions\n",
"@typecheck\n",
"def nice_func2(a: NDArr(dtype=int)):\n",
"    return a\n",
"\n",
"show(nice_func2, np.random.rand(10).astype(int), np.random.rand(10))\n",
"\n",
"## by passing : via [..] syntax you can\n",
"## restrict only a subset of dimensions\n",
"@typecheck\n",
"def nice_func3(a: NDArr[:, 20]):\n",
"    return a\n",
"\n",
"show(nice_func3, np.random.rand(10, 20), np.random.rand(10, 10))\n",
"\n",
"## you can also do both!\n",
"## later restrictions override or extend earlier ones\n",
"@typecheck\n",
"def nice_func4(a: NDArr[:, 20](dtype=int)):\n",
"    return a\n",
"\n",
"show(nice_func4, np.random.rand(10, 20).astype(int), np.random.rand(10, 20))\n",
"\n",
"## or combine them in any way\n",
"@typecheck\n",
"def nice_func5(a: NDArr(ndim=3, dtype=int)):\n",
"    return a\n",
"\n",
"show(nice_func5, np.random.rand(2, 3, 4).astype(int), np.random.rand(2, 3).astype(int))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Definition"
]
},
{
"cell_type": "code",
"execution_count": 111,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"from typecheck import typecheck, InputParameterError\n",
"\n",
"import abc\n",
"from typing import Callable, Any\n",
"\n",
"class ParameterizedTypeHintPredicate(metaclass=abc.ABCMeta):\n",
"    \"\"\"Base class for type-hint predicates that can be re-parameterized.\n",
"\n",
"    An instance is used directly as an annotation: the type checker calls it\n",
"    with the argument value and expects a truthy result.  Calling it with\n",
"    keyword arguments only returns a *new* instance whose constructor\n",
"    parameters are the current ones updated with those keywords.\n",
"\n",
"    Subclasses must implement check() and accept every key of their instance\n",
"    __dict__ as a constructor keyword argument.\n",
"    \"\"\"\n",
"\n",
"    # Private sentinel: lets an explicit None argument be *checked* (and\n",
"    # rejected) instead of being mistaken for 'no argument was given'.\n",
"    _NO_ARGUMENT = object()\n",
"\n",
"    @abc.abstractmethod\n",
"    def check(self, argument):\n",
"        \"\"\"Return True when argument satisfies this predicate.\"\"\"\n",
"\n",
"    def __call__(self, argument=_NO_ARGUMENT, **kwargs):\n",
"        \"\"\"Check argument, or build a re-parameterized copy from kwargs.\n",
"\n",
"        If argument is given, kwargs are ignored and check(argument) is returned.\n",
"        Raises ValueError when neither an argument nor kwargs are given.\n",
"        \"\"\"\n",
"        if argument is not self._NO_ARGUMENT:\n",
"            return self.check(argument)\n",
"        if kwargs:\n",
"            # copy, so the original (shared) instance is not modified\n",
"            param_dict = dict(self.__dict__)\n",
"            param_dict.update(kwargs)\n",
"            return self.__class__(**param_dict)\n",
"        raise ValueError('Not enough params for constructor')\n",
" \n",
"class NumpyNdArrayTypeChecker(ParameterizedTypeHintPredicate):\n",
"    \"\"\"Predicate accepting numpy arrays that match optional restrictions.\n",
"\n",
"    shape            -- exact shape; an integer n is treated as (n,)\n",
"    dtype            -- required dtype (anything comparable to ndarray.dtype)\n",
"    ndim             -- required number of dimensions\n",
"    dim_restrictions -- iterable of (axis_index, size) pairs\n",
"\n",
"    Restrictions left as None are not checked.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, shape=None, dtype=None, ndim=None, dim_restrictions=None):\n",
"        if isinstance(shape, (int, np.integer)):\n",
"            shape = (shape,)\n",
"        elif shape is not None:\n",
"            # a list such as [10, 20] would otherwise never equal a shape tuple\n",
"            shape = tuple(shape)\n",
"        self.shape = shape\n",
"        self.dtype = dtype\n",
"        self.ndim = ndim\n",
"        self.dim_restrictions = dim_restrictions\n",
"\n",
"    def check(self, argument):\n",
"        \"\"\"Return True if argument is an ndarray meeting every restriction.\"\"\"\n",
"        if not isinstance(argument, np.ndarray):\n",
"            return False\n",
"        if self.shape is not None and argument.shape != self.shape:\n",
"            return False\n",
"        if self.dtype is not None and argument.dtype != self.dtype:\n",
"            return False\n",
"        if self.ndim is not None and argument.ndim != self.ndim:\n",
"            return False\n",
"        for dim_n, dim_size in self.dim_restrictions or ():\n",
"            try:\n",
"                actual_size = argument.shape[dim_n]\n",
"            except IndexError:\n",
"                # the array has no such axis, so it cannot match\n",
"                return False\n",
"            if actual_size != dim_size:\n",
"                return False\n",
"        return True\n",
"\n",
"    def __getitem__(self, shape_specs):\n",
"        \"\"\"Build a checker from numpy-style index syntax.\n",
"\n",
"        Each entry describes one axis: an integer requires that size, a slice\n",
"        (e.g. :) leaves the axis unrestricted.  The number of entries becomes\n",
"        ndim; other restrictions already set (e.g. dtype) are kept.\n",
"\n",
"        NDArr[:, :, 3, :, :, 5]\n",
"        =>\n",
"        ndim=6, dim_restrictions=[(2, 3), (5, 5)]\n",
"        # the axis at index 2 has size 3, the axis at index 5 has size 5\n",
"\n",
"        Raises TypeError for Ellipsis, which is not supported.\n",
"        \"\"\"\n",
"        if not isinstance(shape_specs, (tuple, list)):\n",
"            # a single index such as NDArr[10] or NDArr[:]\n",
"            shape_specs = (shape_specs,)\n",
"\n",
"        dim_rst = []\n",
"        for i, dim_spec in enumerate(shape_specs):\n",
"            if dim_spec is Ellipsis:\n",
"                raise TypeError('Ellipsis is not supported in NDArr[...] specs')\n",
"            if not isinstance(dim_spec, slice):\n",
"                dim_rst.append((i, dim_spec))\n",
"\n",
"        return self(ndim=len(shape_specs), dim_restrictions=dim_rst)\n",
"\n",
"\n",
"# Unrestricted base checker; restrict it with NDArr[...] or NDArr(shape=..., dtype=..., ndim=...)\n",
"NDArr = NumpyNdArrayTypeChecker()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tests:"
]
},
{
"cell_type": "code",
"execution_count": 113,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"All tests have passed!\n"
]
}
],
"source": [
"@typecheck\n",
"def test_fails(f: callable, err=Exception):\n",
"    \"\"\"Raise RuntimeError unless calling f() raises an exception of type err.\n",
"\n",
"    Exceptions that are not instances of err propagate unchanged.\n",
"    \"\"\"\n",
"    try:\n",
"        f()\n",
"    except err:\n",
"        return\n",
"    expected = '' if err is Exception else ' ' + getattr(err, '__name__', str(err))\n",
"    raise RuntimeError('Test failed! Try-catch should have raised error' + expected)\n",
"\n",
"@typecheck\n",
"def test_passes(f: callable):\n",
"    \"\"\"Raise RuntimeError, chained to the original error, if calling f() raises.\n",
"\n",
"    Only Exception subclasses are converted; KeyboardInterrupt and SystemExit\n",
"    still propagate.\n",
"    \"\"\"\n",
"    try:\n",
"        f()\n",
"    except Exception as exc:\n",
"        raise RuntimeError('Test failed! (must be no errors here)') from exc\n",
"\n",
"def test():\n",
"    \"\"\"Run all NDArr/typecheck regression checks, then print a success message.\n",
"\n",
"    A failing check raises RuntimeError (or the unexpected exception itself).\n",
"    \"\"\"\n",
"    ## test for test :) - the inner helper's error must surface\n",
"    test_fails(lambda: test_fails(lambda: 1/0, err=InputParameterError), err=ZeroDivisionError)\n",
"    test_fails(lambda: test_passes(lambda: 1/0), err=RuntimeError)\n",
"\n",
"    @typecheck\n",
"    def func1(a: np.ndarray) -> np.ndarray:\n",
"        return a\n",
"\n",
"    test_passes(lambda: func1(np.arange(10)))\n",
"    test_fails(lambda: func1('str'), err=InputParameterError)\n",
"\n",
"    @typecheck\n",
"    def func2(a: NDArr) -> NDArr:\n",
"        return a\n",
"\n",
"    test_passes(lambda: func2(np.arange(10)))\n",
"    test_fails(lambda: func2('str'), err=InputParameterError)\n",
"\n",
"    @typecheck\n",
"    def func3(a: NDArr(shape=(10,))) -> NDArr:\n",
"        return a\n",
"\n",
"    test_passes(lambda: func3(np.arange(10)))\n",
"    test_fails(lambda: func3('str'), err=InputParameterError)\n",
"    test_fails(lambda: func3(np.arange(20)), err=InputParameterError)\n",
"\n",
"    @typecheck\n",
"    def func4(a: NDArr(shape=(10,), dtype=np.int32)) -> NDArr:\n",
"        return a\n",
"\n",
"    test_passes(lambda: func4(np.arange(10, dtype=np.int32)))\n",
"    test_fails(lambda: func4('str'), err=InputParameterError)\n",
"    test_fails(lambda: func4(np.arange(20)), err=InputParameterError)\n",
"    test_fails(lambda: func4(np.arange(10, dtype=float)), err=InputParameterError)\n",
"    # np.int64 rather than the builtin int: int maps to int32 on some platforms\n",
"    # (e.g. Windows), which would make this negative case pass the dtype check\n",
"    test_fails(lambda: func4(np.arange(10, dtype=np.int64)), err=InputParameterError)\n",
"\n",
"    @typecheck\n",
"    def func5(a: NDArr[10]) -> NDArr:\n",
"        return a\n",
"\n",
"    test_passes(lambda: func5(np.arange(10)))\n",
"    test_fails(lambda: func5('str'), err=InputParameterError)\n",
"    test_fails(lambda: func5(np.arange(20)), err=InputParameterError)\n",
"\n",
"    @typecheck\n",
"    def func5_1(a: NDArr[10, 20]) -> NDArr:\n",
"        return a\n",
"\n",
"    test_passes(lambda: func5_1(np.random.rand(10, 20)))\n",
"    test_fails(lambda: func5_1('str'), err=InputParameterError)\n",
"    test_fails(lambda: func5_1(np.arange(20)), err=InputParameterError)\n",
"    test_fails(lambda: func5_1(np.random.rand(10, 30)), err=InputParameterError)\n",
"\n",
"    @typecheck\n",
"    def func6(a: NDArr[10, :]) -> NDArr:\n",
"        return a\n",
"\n",
"    test_passes(lambda: func6(np.random.rand(10, 20)))\n",
"    test_passes(lambda: func6(np.random.rand(10, 30)))\n",
"    test_fails(lambda: func6('str'), err=InputParameterError)\n",
"    test_fails(lambda: func6(np.arange(20)), err=InputParameterError)\n",
"    test_fails(lambda: func6(np.arange(10)), err=InputParameterError)\n",
"\n",
"    @typecheck\n",
"    def func6_1(a: NDArr[:]) -> NDArr:\n",
"        return a\n",
"\n",
"    test_passes(lambda: func6_1(np.random.rand(10)))\n",
"    test_passes(lambda: func6_1(np.random.rand(20)))\n",
"    test_fails(lambda: func6_1('str'), err=InputParameterError)\n",
"    test_fails(lambda: func6_1(np.random.rand(10, 1)), err=InputParameterError)\n",
"    test_fails(lambda: func6_1(np.random.rand(10, 20)), err=InputParameterError)\n",
"\n",
"    @typecheck\n",
"    def func7(a: NDArr(dtype=int)[10, :]) -> NDArr:\n",
"        return a\n",
"\n",
"    test_passes(lambda: func7(np.random.rand(10, 20).astype(dtype=int)))\n",
"    test_passes(lambda: func7(np.random.rand(10, 30).astype(dtype=int)))\n",
"    test_fails(lambda: func7('str'), err=InputParameterError)\n",
"    test_fails(lambda: func7(np.arange(20)), err=InputParameterError)\n",
"    test_fails(lambda: func7(np.arange(10)), err=InputParameterError)\n",
"    test_fails(lambda: func7(np.random.rand(10, 20)), err=InputParameterError)\n",
"    test_fails(lambda: func7(np.random.rand(10, 30)), err=InputParameterError)\n",
"\n",
"    @typecheck\n",
"    def func8(a: NDArr[10, :](dtype=int)) -> NDArr:\n",
"        return a\n",
"\n",
"    test_passes(lambda: func8(np.random.rand(10, 20).astype(dtype=int)))\n",
"    test_passes(lambda: func8(np.random.rand(10, 30).astype(dtype=int)))\n",
"    test_fails(lambda: func8('str'), err=InputParameterError)\n",
"    test_fails(lambda: func8(np.arange(20)), err=InputParameterError)\n",
"    test_fails(lambda: func8(np.arange(10)), err=InputParameterError)\n",
"    test_fails(lambda: func8(np.random.rand(10, 20)), err=InputParameterError)\n",
"    test_fails(lambda: func8(np.random.rand(10, 30)), err=InputParameterError)\n",
"\n",
"    print('All tests have passed!')\n",
"\n",
"test()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.1"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment