Skip to content

Instantly share code, notes, and snippets.

@asford
Last active January 3, 2018 02:03
Show Gist options
  • Save asford/5a630db35f08208adf107ea18a8f3f69 to your computer and use it in GitHub Desktop.
Save asford/5a630db35f08208adf107ea18a8f3f69 to your computer and use it in GitHub Desktop.
NumPy structured dtypes mixing metadata with numeric fields, plus function composition in a curried framework
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import numpy",
"execution_count": 1,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import toolz\nfrom toolz import curry",
"execution_count": 40,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
@curry
def pairwise_broadcast_with_metadata(func, field, arr_a, arr_b, result_descr=None):
    """Apply ``func`` to every (a, b) pair of records, carrying metadata along.

    ``arr_a`` and ``arr_b`` are structured arrays sharing the same field names.
    ``field`` names the numeric payload handed to ``func``; every other field
    is treated as metadata and copied into the result with an ``_a`` / ``_b``
    suffix. Both inputs are flattened, so the result has shape
    ``(arr_a.size, arr_b.size)`` and entry ``[i, j]`` pairs ``arr_a.flat[i]``
    with ``arr_b.flat[j]``.

    The function is curried (``toolz.curry``), so it can be partially applied
    with ``func``, ``field`` and ``result_descr`` to build a reusable binary
    operation on structured arrays.

    Parameters
    ----------
    func : callable
        ``func(a_values, b_values)``; receives the broadcast ``field`` arrays,
        each of shape ``(arr_a.size, arr_b.size) + field_shape``, and must
        return something assignable to the result field.
    field : str
        Name of the payload field; must exist in both inputs with equal dtype.
    arr_a, arr_b : numpy.ndarray
        Structured arrays with the same set of field names.
    result_descr : tuple, optional
        A single ``(name, dtype[, shape])`` field description for the output
        payload. Defaults to ``field`` with the input field's dtype.

    Returns
    -------
    numpy.ndarray
        Structured array of shape ``(arr_a.size, arr_b.size)`` whose fields are
        the ``<meta>_a`` fields, then the ``<meta>_b`` fields, then the result
        field.

    Raises
    ------
    AssertionError
        If an input is not structured, ``field`` is missing or differs in
        dtype between inputs, the metadata field sets differ, or
        ``result_descr`` does not describe exactly one field.
    """
    names_a = arr_a.dtype.names
    names_b = arr_b.dtype.names

    # Validate everything before doing any array work.
    assert names_a is not None and names_b is not None, "inputs must be structured arrays"
    assert field in names_a and field in names_b
    assert arr_a.dtype[field] == arr_b.dtype[field]

    if result_descr is None:
        result_field = field
        result_dtype = arr_a.dtype[field]
    else:
        # Validate result_descr by parsing it into a one-field dtype.
        parsed = numpy.dtype([result_descr])
        assert len(parsed.names) == 1
        result_field = parsed.names[0]
        result_dtype = parsed[result_field]

    # Compare names with ==/!=, not `is`: identity of equal strings depends on
    # interning and fails for field names constructed at runtime.
    meta_a = [n for n in names_a if n != field]
    meta_b = [n for n in names_b if n != field]
    assert set(meta_a) == set(meta_b)

    # Build from per-field dtypes rather than dtype.descr, which can contain
    # unnamed padding entries for dtypes with explicit offsets/alignment.
    final_dtype = numpy.dtype(
        [(n + "_a", arr_a.dtype[n]) for n in meta_a]
        + [(n + "_b", arr_b.dtype[n]) for n in meta_b]
        + [(result_field, result_dtype)]
    )

    # Outer-product broadcast: rows index arr_a, columns index arr_b.
    ba, bb = numpy.broadcast_arrays(arr_a.reshape((-1, 1)), arr_b.reshape((1, -1)))

    result = numpy.empty_like(ba, dtype=final_dtype)

    # Copy "metadata" into result.
    for n in meta_a:
        result[n + "_a"] = ba[n]
    for n in meta_b:
        result[n + "_b"] = bb[n]

    # Perform op on the payload and store result.
    result[result_field] = func(ba[field], bb[field])

    return result
"execution_count": 41,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
# Record layout: an integer metadata field "i" plus a float payload "f" that
# holds one 2x4 matrix per record.
meta_d = numpy.dtype(
    [("i", int), ("f", float, (2, 4))]
)
"execution_count": 42,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
def transpose_final_dims(d_array):
    """Return ``d_array`` with only its last two axes swapped.

    Leading axes keep their positions, so shape ``(..., m, n)`` becomes
    ``(..., n, m)``. The input must have at least two dimensions.
    """
    assert d_array.ndim >= 2
    return d_array.swapaxes(-1, -2)
"execution_count": 43,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
def _matmul_transposed(lhs, rhs):
    """Matrix-multiply ``lhs`` by ``rhs`` with rhs's last two axes swapped."""
    return lhs @ transpose_final_dims(rhs)


def _elementwise_sum(lhs, rhs):
    """Add the two payload arrays elementwise."""
    return lhs + rhs


# Partial applications of the curried pairwise_broadcast_with_metadata: each
# result is a binary function of (arr_a, arr_b) acting on the "f" field.
# With 2x4 payloads, (2, 4) @ (4, 2) yields a 2x2 product per pair.
custom_mult = pairwise_broadcast_with_metadata(
    _matmul_transposed, "f", result_descr=("prod", float, (2, 2)))

custom_add = pairwise_broadcast_with_metadata(
    _elementwise_sum, "f", result_descr=("sum", float, (2, 4)))
"execution_count": 50,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
# Two small structured inputs with deterministic payloads. (numpy.empty would
# leave the "f" field uninitialized, so results would be arbitrary memory.)
sa = numpy.zeros(3, meta_d)
sa["i"] = numpy.arange(3)
sa["f"] = numpy.arange(3 * 2 * 4, dtype=float).reshape(3, 2, 4)

sb = numpy.zeros(4, meta_d)
sb["i"] = numpy.arange(10, 14)
sb["f"] = numpy.arange(4 * 2 * 4, dtype=float).reshape(4, 2, 4) / 10

mult_result = custom_mult(sa, sb)
sum_result = custom_add(sa, sb)

# Invariants: one entry per (a, b) pair, metadata carried through with
# suffixes, and payloads matching a direct broadcast computation.
assert mult_result.shape == sum_result.shape == (3, 4)
assert mult_result.dtype.names == ("i_a", "i_b", "prod")
assert sum_result.dtype.names == ("i_a", "i_b", "sum")
assert (mult_result["i_a"] == sa["i"][:, None]).all()
assert (mult_result["i_b"] == sb["i"][None, :]).all()
assert numpy.allclose(
    mult_result["prod"],
    sa["f"][:, None] @ numpy.swapaxes(sb["f"], -1, -2)[None, :],
)
assert numpy.allclose(sum_result["sum"], sa["f"][:, None] + sb["f"][None, :])
"execution_count": 51,
"outputs": []
}
],
"metadata": {
"language_info": {
"name": "python",
"version": "3.5.4",
"file_extension": ".py",
"nbconvert_exporter": "python",
"mimetype": "text/x-python",
"pygments_lexer": "ipython3",
"codemirror_mode": {
"name": "ipython",
"version": 3
}
},
"_draft": {
"nbviewer_url": "https://gist.github.com/5a630db35f08208adf107ea18a8f3f69"
},
"toc": {
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"toc_cell": false,
"toc_position": {},
"toc_section_display": "block",
"toc_window_display": false
},
"gist": {
"id": "5a630db35f08208adf107ea18a8f3f69",
"data": {
"description": "numpy metadata-with-numeric dtypes and function composition, curried framework",
"public": true
}
},
"kernelspec": {
"name": "conda-env-dev-py",
"display_name": "Python [conda env:dev]",
"language": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment