Created
November 25, 2019 18:01
-
-
Save saulshanabrook/b35fc3f50ca8ee9fdd5585e48c081314 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from dataclasses import *\n", | |
"from typing import *" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Array:\n", | |
" def __getitem__(self, idxs: Tuple[int, ...]) -> int:\n", | |
" ...\n", | |
" \n", | |
" def shape(self) -> Tuple[int, ...]:\n", | |
" ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"@dataclass\n", | |
"class Arange(Array):\n", | |
" length: int\n", | |
" \n", | |
" def __getitem__(self, idxs):\n", | |
" return idxs[0]\n", | |
" \n", | |
" def shape(self):\n", | |
" return (self.length,)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"@dataclass\n", | |
"class OuterProduct(Array):\n", | |
" left: Array\n", | |
" right: Array\n", | |
" \n", | |
" def __getitem__(self, indxs):\n", | |
" left_d = len(self.left.shape())\n", | |
" return self.left[indxs[:left_d]] * self.right[indxs[:left_d]]\n", | |
" \n", | |
" def shape(self):\n", | |
" return self.left.shape() + self.right.shape()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"range(0, 10)" | |
] | |
}, | |
"execution_count": 26, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"range(10)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 48, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from typing_extensions import Protocol\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 50, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from __future__ import annotations" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 53, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%config ZMQInteractiveShell.ast_node_interactivity='last_expr_or_assign'" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 55, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class SingleIter(Protocol):\n", | |
" def __call__(self) -> Union[None, Tuple[int, SingleIter]]:\n", | |
" ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"MultiIter = Callable[[], Union[None, Tuple[int, List[MultiIter]]]]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 61, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class MultiIter(Protocol):\n", | |
" def __call__(self) -> Union[None, Tuple[int, List[MultiIter]]]:\n", | |
" ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 76, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class IterArray(Protocol):\n", | |
" def multi_iter(self) -> MultiIter:\n", | |
" ...\n", | |
" \n", | |
" def shape(self) -> Tuple[int, ...]:\n", | |
" s = []\n", | |
" for d in range(self.dim()):\n", | |
" length = 0\n", | |
" multi_iter = self.multi_iter()\n", | |
" while True:\n", | |
" res = multi_iter()\n", | |
" if not res:\n", | |
" break\n", | |
" _, multi_iters = res\n", | |
" multi_iter = multi_iters[d]\n", | |
" length += 1\n", | |
" s.append(length)\n", | |
" return s\n", | |
" \n", | |
" def dim(self) -> int:\n", | |
" ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 83, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def increment_index(indxs: Tuple[int, ...], dim: int) -> Tuple[int, ...]:\n", | |
" indxs = list(indxs)\n", | |
" indxs[dim] += 1\n", | |
" return tuple(indxs)\n", | |
"\n", | |
"@dataclass\n", | |
"class IterArange:\n", | |
" shape: Tuple[int]\n", | |
" current_position: Tuple[int]\n", | |
"\n", | |
" def __call__(self):\n", | |
" in_bounds = all(\n", | |
" dim_index < dim_length\n", | |
" for dim_length, dim_index\n", | |
" in zip(self.shape, self.current_position)\n", | |
" )\n", | |
" if not in_bounds:\n", | |
" return None\n", | |
" return sum(self.current_position), [\n", | |
" IterArange(\n", | |
" self.shape,\n", | |
" increment_index(self.current_position, d)\n", | |
" ) for d in range(len(self.shape))\n", | |
" ]\n", | |
" \n", | |
"@dataclass\n", | |
"class Arange(IterArray):\n", | |
" _shape: Tuple[int]\n", | |
" \n", | |
" def dim(self):\n", | |
" return len(self._shape)\n", | |
" \n", | |
" def multi_iter(self):\n", | |
" return IterArange(self._shape, (0,) * len(self._shape))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 86, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"IterArange(shape=(5, 10, 2), current_position=(0, 0, 0))" | |
] | |
}, | |
"execution_count": 86, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"i = Arange((5, 10, 2)).multi_iter()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"@dataclass\n", | |
"class COO(IterArray):\n", | |
" original_array: IterArray\n", | |
" filter_fn: Callable[[int], bool] \n", | |
" \n", | |
" def dim(self):\n", | |
" return self.original_array.dim()\n", | |
" \n", | |
" def multi_iter(self):\n", | |
" return IterArange(self._shape, (0,) * len(self._shape))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 90, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(1,\n", | |
" [IterArange(shape=(5, 10, 2), current_position=(1, 0, 1)),\n", | |
" IterArange(shape=(5, 10, 2), current_position=(0, 1, 1)),\n", | |
" IterArange(shape=(5, 10, 2), current_position=(0, 0, 2))])" | |
] | |
}, | |
"execution_count": 90, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"@dataclass\n", | |
"class Filter(IterArray):\n", | |
" original_array: IterArray\n", | |
" filter_fn: Callable[[int], bool] \n", | |
" \n", | |
" def dim(self):\n", | |
" return self.original_array.dim()\n", | |
" \n", | |
" def multi_iter(self):\n", | |
" return IterArange(self._shape, (0,) * len(self._shape))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"@dataclass\n", | |
"class RangeIter:\n", | |
" current_state: int\n", | |
" length: int\n", | |
" \n", | |
" def __call__(self):\n", | |
" if self.current_state < self.length:\n", | |
" return (self.current_state, RangeIter(self.current_state + 1, self.length))\n", | |
" return None" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def range_iter(n):\n", | |
" return RangeIter(0, n)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 45, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"i = range_iter(1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 46, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"current_value, new_i = i()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 47, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"new_i()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"81" | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"class Array:\n", | |
" \n", | |
" def __i\n", | |
" def __getitem__(self, idxs: Tuple[int, ...]) -> int:\n", | |
" ...\n", | |
" \n", | |
" def shape(self) -> Tuple[int, ...]:\n", | |
" ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"a = Array()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"current_val = next(a.iter(), 3)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment