Skip to content

Instantly share code, notes, and snippets.

@shoyer
Created August 1, 2017 06:47
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save shoyer/c700193625347eb68fee4d1f0dc8c0c8 to your computer and use it in GitHub Desktop.
Save shoyer/c700193625347eb68fee4d1f0dc8c0c8 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "NumPy vindex.ipynb",
"version": "0.3.2",
"views": {},
"default_view": {},
"provenance": [
{
"file_id": "/piper/depot/google3/research/colab/frontend/notebooks/scratchpad.ipynb",
"timestamp": 1501548996561
},
{
"file_id": "0Bx_pzjPHF_34bEdnazB5SkNiRHc",
"timestamp": 1468447836766
}
]
}
},
"cells": [
{
"metadata": {
"id": "IfoQHanqhTKS",
"colab_type": "text"
},
"source": [
"# Pure Python ``vindex`` implementation\n",
"\n",
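"This is a pure-Python prototype of the ``vindex`` (vectorized indexing) attribute described in [this proposal](https://github.com/numpy/numpy/pull/6256): ``arr.vindex[key]`` always puts the broadcast advanced-index dimensions first, whereas plain ``arr[key]`` leaves them in place when the advanced indices are adjacent."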
],
"cell_type": "markdown"
},
{
"metadata": {
"id": "lIYdn1woOS1n",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"source": [
"# Copyright 2017 Google Inc.\n",
"\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License.\n",
"\n",
"import numpy as np\n",
"\n",
"\n",
def is_contiguous(positions):
    """Do ``positions`` form a run of consecutive, increasing integers?

    Empty and single-element sequences are trivially contiguous.
    """
    return all(current == previous + 1
               for previous, current in zip(positions, positions[1:]))


def _is_basic_index(k):
    """Is ``k`` an index entry that never counts as an advanced index?

    NumPy treats slices, ``None`` (``np.newaxis``) and ``Ellipsis`` as basic
    indices; any of them between two advanced indices makes those advanced
    indices non-adjacent.
    """
    return k is None or k is Ellipsis or isinstance(k, slice)


def _advanced_operand(k):
    """Return ``(index_array, consumed_ndim)`` for one advanced index entry.

    A boolean array indexes like the integer positions of its true entries:
    an N-d mask consumes N dimensions of the indexed array but contributes a
    single 1-d index to broadcasting. Every other entry consumes one dimension.
    """
    index = np.asarray(k)
    if index.dtype == np.bool_:
        return np.flatnonzero(index), index.ndim
    return index, 1


def _advanced_subspace(key, ndim=None):
    """Locate the broadcast advanced-index dimensions in ``array[key]``.

    Args:
      key: an indexing key, either a tuple or a single index entry.
      ndim: ``array.ndim``. Required when an ``Ellipsis`` precedes the
        advanced indices; also needed to report ``result_ndim``.

    Returns:
      None when no transpose is needed: there are no advanced indices, or
      they are separated by a slice/None/Ellipsis (NumPy then already puts
      their dimensions first). Otherwise ``(start, adv_ndim, result_ndim)``:
      the advanced dimensions occupy axes ``start:start + adv_ndim`` of
      ``array[key]``, whose ndim is ``result_ndim`` (None if ``ndim`` is None).

    Raises:
      ValueError: an ``Ellipsis`` precedes the advanced indices and ``ndim``
        is None.
    """
    if not isinstance(key, tuple):
        key = (key,)

    positions = [i for i, k in enumerate(key) if not _is_basic_index(k)]
    if not positions or not is_contiguous(positions):
        return None

    operands = [_advanced_operand(key[i]) for i in positions]
    adv_ndim = len(np.broadcast(*[index for index, _ in operands]).shape)

    # An explicit Ellipsis (or NumPy's implicit trailing one) expands to the
    # dimensions that the other entries leave unconsumed.
    n_slices = sum(1 for k in key if isinstance(k, slice))
    n_newaxes = sum(1 for k in key if k is None)
    leftover = None
    if ndim is not None:
        leftover = ndim - n_slices - sum(used for _, used in operands)

    # Every entry before the first advanced index is basic; count the result
    # dimensions those entries produce.
    start = 0
    for k in key[:positions[0]]:
        if k is Ellipsis:
            if leftover is None:
                raise ValueError('ndim is required to locate advanced indices '
                                 'that follow an Ellipsis')
            start += leftover
        else:  # a slice or None each produces one result dimension
            start += 1

    result_ndim = None
    if leftover is not None:
        result_ndim = n_slices + n_newaxes + leftover + adv_ndim
    return start, adv_ndim, result_ndim


def advanced_indexer_subspaces(key, ndim=None):
    """Axes mapping NumPy's mixed-indexing layout to the vindex layout.

    Args:
      key: an indexing key, either a tuple or a single index entry.
      ndim: ``ndim`` of the indexed array. Only required when an
        ``Ellipsis`` precedes the advanced indices.

    Returns:
      ``(mixed_positions, vindex_positions)``: ``np.moveaxis(array[key],
      mixed_positions, vindex_positions)`` gives the vindex result, and
      swapping the two arguments converts back. Both are empty when no
      transpose is needed.
    """
    layout = _advanced_subspace(key, ndim)
    if layout is None:
        return (), ()
    start, adv_ndim, _ = layout
    vindex_positions = np.arange(adv_ndim)
    return start + vindex_positions, vindex_positions


class VectorizedIndexer(object):
    """Implements ``array.vindex[key]``: advanced-index dimensions come first.

    Plain NumPy indexing leaves the broadcast advanced-index dimensions in
    place when the advanced indices are adjacent; vindex moves them to the
    front, so the result layout does not depend on adjacency.
    """

    def __init__(self, array):
        self._array = array

    def __getitem__(self, key):
        mixed_positions, vindex_positions = advanced_indexer_subspaces(
            key, self._array.ndim)
        return np.moveaxis(self._array[key], mixed_positions, vindex_positions)

    def __setitem__(self, key, value):
        layout = _advanced_subspace(key, self._array.ndim)
        if layout is None:
            self._array[key] = value
            return
        start, adv_ndim, result_ndim = layout
        value = np.asanyarray(value)
        # ``value`` broadcasts against the vindex-ordered target from its
        # trailing axes, so give it the target's full rank before moving axes;
        # otherwise scalars and lower-rank values make np.moveaxis fail or
        # move the wrong axes.
        missing = result_ndim - value.ndim
        if missing > 0:
            value = value.reshape((1,) * missing + value.shape)
        vindex_positions = np.arange(adv_ndim)
        self._array[key] = np.moveaxis(value, vindex_positions,
                                       start + vindex_positions)
"\n",
" \n",
class VindexArray(np.ndarray):
    """NumPy array subclass exposing the prototype ``vindex`` attribute.

    Obtain one with ``arr.view(VindexArray)``; ``arr.vindex[key]`` then
    indexes with the broadcast advanced-index dimensions moved to the front.
    """

    vindex = property(
        lambda self: VectorizedIndexer(self),
        doc="A ``VectorizedIndexer`` wrapping this array (new one per access).")
],
"cell_type": "code",
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "hILp9Zi1hNo0",
"colab_type": "text"
},
"source": [
"## Tests"
],
"cell_type": "markdown"
},
{
"metadata": {
"id": "giAX7lwff6Mc",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"source": [
assert is_contiguous([1, 2, 3])
assert not is_contiguous([1, 3])

x = np.arange(3 * 4 * 5).reshape((3, 4, 5)).view(VindexArray)
x_before = x.copy()

# A lone integer or leading adjacent advanced indices match NumPy's layout.
np.testing.assert_array_equal(x.vindex[0], x[0])
np.testing.assert_array_equal(x.vindex[[1, 2], [1, 2]], x[[1, 2], [1, 2]])

# The broadcast advanced-index dimension always comes first.
assert x.vindex[[0, 1], [0, 1], :].shape == (2, 5)
assert x.vindex[[0, 1], :, [0, 1]].shape == (2, 4)
assert x.vindex[:, [0, 1], [0, 1]].shape == (2, 3)
np.testing.assert_array_equal(x.vindex[:, [0, 1], [0, 1]],
                              x[:, [0, 1], [0, 1]].T)

# Assigning a vindex selection back to itself must round-trip unchanged.
x.vindex[[0, 1], [0, 1], :] = x.vindex[[0, 1], [0, 1], :]
x.vindex[[0, 1], :, [0, 1]] = x.vindex[[0, 1], :, [0, 1]]
x.vindex[:, [0, 1], [0, 1]] = x.vindex[:, [0, 1], [0, 1]]
np.testing.assert_array_equal(x, x_before)

# Assignment takes values in vindex order (advanced dimension first).
y = x.copy()
y.vindex[:, [0, 1], [0, 1]] = np.array([[1, 2, 3], [4, 5, 6]])
np.testing.assert_array_equal(y[:, [0, 1], [0, 1]], [[1, 4], [2, 5], [3, 6]])
],
"cell_type": "code",
"execution_count": 0,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment