Last active
May 16, 2023 10:19
-
-
Save ev-br/27dee81aae8e24193db8082ee886f6e4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "6da0c516", | |
"metadata": {}, | |
"source": [ | |
"## Type promotion with scalars under NEP50\n", | |
"\n", | |
"Assuming that `np.result_type` does type promotion of two arrays, mimic what it does with scalar + array." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "1397d81b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"np._set_promotion_state(\"weak\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "237c6f8b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"categories = [('bool',),\n", | |
" ('uint8', 'int8', 'int16', 'int32', 'int64'), # plus other unsigned ints\n", | |
" ('float16', 'float32', 'float64'), # plus longdouble\n", | |
" ('complex64', 'complex128')]\n", | |
"\n", | |
"\n", | |
"def category(dtyp):\n", | |
" for j, cat in enumerate(categories):\n", | |
" if dtyp in cat:\n", | |
" return j\n", | |
" raise ValueError(f\"{dtyp}?\")\n", | |
"\n", | |
"\n", | |
"dtype_for_cat = {0: \"bool\",\n", | |
" 1: \"int64\", # ignore windows int32\n", | |
" 2: \"float64\",\n", | |
" 3: \"complex128\"}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "4b79d2da", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def pre_promote(weak_dtype, non_weak_dtype):\n", | |
" \"\"\"Return the dtype for the promoted arguments.\"\"\"\n", | |
" cat_weak = category(weak_dtype)\n", | |
" cat_non_weak = category(non_weak_dtype)\n", | |
" if cat_weak <= cat_non_weak:\n", | |
" # ignore the 'weak' dtype\n", | |
" return non_weak_dtype, non_weak_dtype\n", | |
" else:\n", | |
" # weak dtype -> default category\n", | |
" return dtype_for_cat[cat_weak], non_weak_dtype\n", | |
" \n", | |
"def result_type(weak, non_weak):\n", | |
" \"\"\"The dtype of `binop(weak, non_weak)` under NEP 50\"\"\"\n", | |
" weak, non_weak = np.asarray(weak), np.asarray(non_weak)\n", | |
" dt1, dt2 = pre_promote(weak.dtype, non_weak.dtype)\n", | |
" dtyp = np.result_type(dt1, dt2)\n", | |
" return dtyp" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "2c0b937a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"dtype('float32')" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"result_type(True, np.float32(2))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "cdecdf01", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"dtype('float32')" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"(True + np.float32(2)).dtype" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "c1f35436", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"weak + non-weak --> result_type / np_dtype\n", | |
"----------------------------------------------\n", | |
"True + bool --> bool / bool \n", | |
"True + uint8 --> uint8 / uint8 \n", | |
"True + int8 --> int8 / int8 \n", | |
"True + int32 --> int32 / int32 \n", | |
"True + int64 --> int64 / int64 \n", | |
"True + float32 --> float32 / float32 \n", | |
"True + float64 --> float64 / float64 \n", | |
"True + complex64 --> complex64 / complex64 \n", | |
"True + complex128 --> complex128 / complex128\n", | |
"1 + bool --> int64 / int64 \n", | |
"1 + uint8 --> uint8 / uint8 \n", | |
"1 + int8 --> int8 / int8 \n", | |
"1 + int32 --> int32 / int32 \n", | |
"1 + int64 --> int64 / int64 \n", | |
"1 + float32 --> float32 / float32 \n", | |
"1 + float64 --> float64 / float64 \n", | |
"1 + complex64 --> complex64 / complex64 \n", | |
"1 + complex128 --> complex128 / complex128\n", | |
"2.0 + bool --> float64 / float64 \n", | |
"2.0 + uint8 --> float64 / float64 \n", | |
"2.0 + int8 --> float64 / float64 \n", | |
"2.0 + int32 --> float64 / float64 \n", | |
"2.0 + int64 --> float64 / float64 \n", | |
"2.0 + float32 --> float32 / float32 \n", | |
"2.0 + float64 --> float64 / float64 \n", | |
"2.0 + complex64 --> complex64 / complex64 \n", | |
"2.0 + complex128 --> complex128 / complex128\n", | |
"3j + bool --> complex128 / complex128\n", | |
"3j + uint8 --> complex128 / complex128\n", | |
"3j + int8 --> complex128 / complex128\n", | |
"3j + int32 --> complex128 / complex128\n", | |
"3j + int64 --> complex128 / complex128\n", | |
"3j + float32 --> complex128 / complex64 \n", | |
"3j + float64 --> complex128 / complex128\n", | |
"3j + complex64 --> complex64 / complex64 \n", | |
"3j + complex128 --> complex128 / complex128\n" | |
] | |
} | |
], | |
"source": [ | |
"weaks = [True, 1, 2.0, 3j]\n", | |
"non_weaks = [np.asarray(True),\n", | |
" np.uint8(1), np.int8(1), np.int32(1), np.int64(1),\n", | |
" np.float32(1), np.float64(1),\n", | |
" np.complex64(1), np.complex128(1)]\n", | |
"\n", | |
"import itertools\n", | |
"print(\"weak + non-weak --> result_type / np_dtype\")\n", | |
"print(\"----------------------------------------------\")\n", | |
"for w, n in itertools.product(weaks, non_weaks):\n", | |
" res = result_type(w, n)\n", | |
" res_np = np.add(w, n).dtype\n", | |
" print(f\"{str(w):<5} + {n.dtype.name:<10} --> {res.name:<10} / {res_np.name:<10}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "7c2c2353", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"False" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.float64(0.1) == np.float32(0.1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "433118f8", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"0.1 == np.float32(0.1)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "dd09a80b", | |
"metadata": {}, | |
"source": [ | |
"## Use \"inexact\" as a category (floats and complex)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"id": "0acacf8a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"categories = [('bool',),\n", | |
" ('uint8', 'int8', 'int16', 'int32', 'int64'), # plus other unsigned ints\n", | |
" ('float16', 'float32', 'float64', # plus longdouble\n", | |
" 'complex64', 'complex128')]\n", | |
"\n", | |
"\n", | |
"def category(dtyp):\n", | |
" for j, cat in enumerate(categories):\n", | |
" if dtyp in cat:\n", | |
" return j\n", | |
" raise ValueError(f\"{dtyp}?\")\n", | |
"\n", | |
"\n", | |
"dtype_for_cat = {0: \"bool\",\n", | |
" 1: \"int64\", # ignore windows int32\n", | |
" 2: \"float64\",\n", | |
" }\n", | |
"\n", | |
"\n", | |
"def pre_promote(weak_dtype, non_weak_dtype):\n", | |
" \"\"\"Return the dtype for the promoted arguments.\"\"\"\n", | |
" cat_weak = category(weak_dtype)\n", | |
" cat_non_weak = category(non_weak_dtype)\n", | |
" if cat_weak <= cat_non_weak:\n", | |
" # ignore the 'weak' dtype\n", | |
" return non_weak_dtype, non_weak_dtype\n", | |
" else:\n", | |
" # weak dtype -> default category\n", | |
" return dtype_for_cat[cat_weak], non_weak_dtype\n", | |
"\n", | |
" \n", | |
"def result_type(weak, non_weak):\n", | |
" \"\"\"The dtype of `binop(weak, non_weak)` under NEP 50\"\"\"\n", | |
" weak, non_weak = np.asarray(weak), np.asarray(non_weak)\n", | |
" dt1, dt2 = pre_promote(weak.dtype, non_weak.dtype)\n", | |
" dtyp = np.result_type(dt1, dt2)\n", | |
" return dtyp" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"id": "cc80988a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"weak + non-weak --> result_type / np_dtype\n", | |
"----------------------------------------------\n", | |
"True + bool --> bool / bool \n", | |
"True + uint8 --> uint8 / uint8 \n", | |
"True + int8 --> int8 / int8 \n", | |
"True + int32 --> int32 / int32 \n", | |
"True + int64 --> int64 / int64 \n", | |
"True + float32 --> float32 / float32 \n", | |
"True + float64 --> float64 / float64 \n", | |
"True + complex64 --> complex64 / complex64 \n", | |
"True + complex128 --> complex128 / complex128\n", | |
"1 + bool --> int64 / int64 \n", | |
"1 + uint8 --> uint8 / uint8 \n", | |
"1 + int8 --> int8 / int8 \n", | |
"1 + int32 --> int32 / int32 \n", | |
"1 + int64 --> int64 / int64 \n", | |
"1 + float32 --> float32 / float32 \n", | |
"1 + float64 --> float64 / float64 \n", | |
"1 + complex64 --> complex64 / complex64 \n", | |
"1 + complex128 --> complex128 / complex128\n", | |
"2.0 + bool --> float64 / float64 \n", | |
"2.0 + uint8 --> float64 / float64 \n", | |
"2.0 + int8 --> float64 / float64 \n", | |
"2.0 + int32 --> float64 / float64 \n", | |
"2.0 + int64 --> float64 / float64 \n", | |
"2.0 + float32 --> float32 / float32 \n", | |
"2.0 + float64 --> float64 / float64 \n", | |
"2.0 + complex64 --> complex64 / complex64 \n", | |
"2.0 + complex128 --> complex128 / complex128\n", | |
"3j + bool --> float64 / complex128\n", | |
"3j + uint8 --> float64 / complex128\n", | |
"3j + int8 --> float64 / complex128\n", | |
"3j + int32 --> float64 / complex128\n", | |
"3j + int64 --> float64 / complex128\n", | |
"3j + float32 --> float32 / complex64 \n", | |
"3j + float64 --> float64 / complex128\n", | |
"3j + complex64 --> complex64 / complex64 \n", | |
"3j + complex128 --> complex128 / complex128\n" | |
] | |
} | |
], | |
"source": [ | |
"weaks = [True, 1, 2.0, 3j]\n", | |
"non_weaks = [np.asarray(True),\n", | |
" np.uint8(1), np.int8(1), np.int32(1), np.int64(1),\n", | |
" np.float32(1), np.float64(1),\n", | |
" np.complex64(1), np.complex128(1)]\n", | |
"\n", | |
"import itertools\n", | |
"print(\"weak + non-weak --> result_type / np_dtype\")\n", | |
"print(\"----------------------------------------------\")\n", | |
"for w, n in itertools.product(weaks, non_weaks):\n", | |
" res = result_type(w, n)\n", | |
" res_np = np.add(w, n).dtype\n", | |
" print(f\"{str(w):<5} + {n.dtype.name:<10} --> {res.name:<10} / {res_np.name:<10}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "9ca195c2", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.16" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment