Skip to content

Instantly share code, notes, and snippets.

@saulshanabrook
Created November 25, 2019 18:01
Show Gist options
  • Save saulshanabrook/b35fc3f50ca8ee9fdd5585e48c081314 to your computer and use it in GitHub Desktop.
Save saulshanabrook/b35fc3f50ca8ee9fdd5585e48c081314 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import *\n",
"from typing import *"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"class Array:\n",
" def __getitem__(self, idxs: Tuple[int, ...]) -> int:\n",
" ...\n",
" \n",
" def shape(self) -> Tuple[int, ...]:\n",
" ..."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"@dataclass\n",
"class Arange(Array):\n",
" length: int\n",
" \n",
" def __getitem__(self, idxs):\n",
" return idxs[0]\n",
" \n",
" def shape(self):\n",
" return (self.length,)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"@dataclass\n",
"class OuterProduct(Array):\n",
" left: Array\n",
" right: Array\n",
" \n",
" def __getitem__(self, indxs):\n",
" left_d = len(self.left.shape())\n",
" return self.left[indxs[:left_d]] * self.right[indxs[left_d:]]\n",
" \n",
" def shape(self):\n",
" return self.left.shape() + self.right.shape()"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"range(0, 10)"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"range(10)"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"from typing_extensions import Protocol\n"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"from __future__ import annotations"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"%config ZMQInteractiveShell.ast_node_interactivity='last_expr_or_assign'"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [],
"source": [
"class SingleIter(Protocol):\n",
" def __call__(self) -> Union[None, Tuple[int, SingleIter]]:\n",
" ..."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"MultiIter = Callable[[], Union[None, Tuple[int, List['MultiIter']]]]"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
"class MultiIter(Protocol):\n",
" def __call__(self) -> Union[None, Tuple[int, List[MultiIter]]]:\n",
" ..."
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [],
"source": [
"class IterArray(Protocol):\n",
" def multi_iter(self) -> MultiIter:\n",
" ...\n",
" \n",
" def shape(self) -> Tuple[int, ...]:\n",
" s = []\n",
" for d in range(self.dim()):\n",
" length = 0\n",
" multi_iter = self.multi_iter()\n",
" while True:\n",
" res = multi_iter()\n",
" if not res:\n",
" break\n",
" _, multi_iters = res\n",
" multi_iter = multi_iters[d]\n",
" length += 1\n",
" s.append(length)\n",
" return tuple(s)\n",
" \n",
" def dim(self) -> int:\n",
" ..."
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [],
"source": [
"def increment_index(indxs: Tuple[int, ...], dim: int) -> Tuple[int, ...]:\n",
" indxs = list(indxs)\n",
" indxs[dim] += 1\n",
" return tuple(indxs)\n",
"\n",
"@dataclass\n",
"class IterArange:\n",
" shape: Tuple[int, ...]\n",
" current_position: Tuple[int, ...]\n",
"\n",
" def __call__(self):\n",
" in_bounds = all(\n",
" dim_index < dim_length\n",
" for dim_length, dim_index\n",
" in zip(self.shape, self.current_position)\n",
" )\n",
" if not in_bounds:\n",
" return None\n",
" return sum(self.current_position), [\n",
" IterArange(\n",
" self.shape,\n",
" increment_index(self.current_position, d)\n",
" ) for d in range(len(self.shape))\n",
" ]\n",
" \n",
"@dataclass\n",
"class Arange(IterArray):\n",
" _shape: Tuple[int, ...]\n",
" \n",
" def dim(self):\n",
" return len(self._shape)\n",
" \n",
" def multi_iter(self):\n",
" return IterArange(self._shape, (0,) * len(self._shape))"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"IterArange(shape=(5, 10, 2), current_position=(0, 0, 0))"
]
},
"execution_count": 86,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"i = Arange((5, 10, 2)).multi_iter()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@dataclass\n",
"class COO(IterArray):\n",
" original_array: IterArray\n",
" filter_fn: Callable[[int], bool] \n",
" \n",
" def dim(self):\n",
" return self.original_array.dim()\n",
" \n",
" def multi_iter(self):\n",
" return IterArange(self._shape, (0,) * len(self._shape))"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1,\n",
" [IterArange(shape=(5, 10, 2), current_position=(1, 0, 1)),\n",
" IterArange(shape=(5, 10, 2), current_position=(0, 1, 1)),\n",
" IterArange(shape=(5, 10, 2), current_position=(0, 0, 2))])"
]
},
"execution_count": 90,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"@dataclass\n",
"class Filter(IterArray):\n",
" original_array: IterArray\n",
" filter_fn: Callable[[int], bool] \n",
" \n",
" def dim(self):\n",
" return self.original_array.dim()\n",
" \n",
" def multi_iter(self):\n",
" return IterArange(self._shape, (0,) * len(self._shape))"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"@dataclass\n",
"class RangeIter:\n",
" current_state: int\n",
" length: int\n",
" \n",
" def __call__(self):\n",
" if self.current_state < self.length:\n",
" return (self.current_state, RangeIter(self.current_state + 1, self.length))\n",
" return None"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"def range_iter(n):\n",
" return RangeIter(0, n)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"i = range_iter(1)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"current_value, new_i = i()"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"new_i()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"81"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class Array:\n",
" \n",
" def __i\n",
" def __getitem__(self, idxs: Tuple[int, ...]) -> int:\n",
" ...\n",
" \n",
" def shape(self) -> Tuple[int, ...]:\n",
" ..."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"a = Array()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"current_val = next(a.iter(), 3)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment