Skip to content

Instantly share code, notes, and snippets.

@ev-br
Last active May 16, 2023 10:19
Show Gist options
  • Save ev-br/27dee81aae8e24193db8082ee886f6e4 to your computer and use it in GitHub Desktop.
Save ev-br/27dee81aae8e24193db8082ee886f6e4 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "6da0c516",
"metadata": {},
"source": [
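"## Type promotion with scalars under NEP 50\n",
"\n",
"Assuming that `np.result_type` correctly promotes two array dtypes, emulate what NEP 50 prescribes for a binary operation between a Python scalar (the \"weak\" operand, whose value does not affect the result dtype) and a NumPy array or scalar."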
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "1397d81b",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"np._set_promotion_state(\"weak\")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "237c6f8b",
"metadata": {},
"outputs": [],
"source": [
"categories = [('bool',),\n",
" ('uint8', 'int8', 'int16', 'int32', 'int64'), # plus other unsigned ints\n",
" ('float16', 'float32', 'float64'), # plus longdouble\n",
" ('complex64', 'complex128')]\n",
"\n",
"\n",
"def category(dtyp):\n",
" for j, cat in enumerate(categories):\n",
" if dtyp in cat:\n",
" return j\n",
" raise ValueError(f\"{dtyp}?\")\n",
"\n",
"\n",
"dtype_for_cat = {0: \"bool\",\n",
" 1: \"int64\", # ignore windows int32\n",
" 2: \"float64\",\n",
" 3: \"complex128\"}"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "4b79d2da",
"metadata": {},
"outputs": [],
"source": [
"def pre_promote(weak_dtype, non_weak_dtype):\n",
" \"\"\"Return the dtype for the promoted arguments.\"\"\"\n",
" cat_weak = category(weak_dtype)\n",
" cat_non_weak = category(non_weak_dtype)\n",
" if cat_weak <= cat_non_weak:\n",
" # ignore the 'weak' dtype\n",
" return non_weak_dtype, non_weak_dtype\n",
" else:\n",
" # weak dtype -> default category\n",
" return dtype_for_cat[cat_weak], non_weak_dtype\n",
" \n",
"def result_type(weak, non_weak):\n",
" \"\"\"The dtype of `binop(weak, non_weak)` under NEP 50\"\"\"\n",
" weak, non_weak = np.asarray(weak), np.asarray(non_weak)\n",
" dt1, dt2 = pre_promote(weak.dtype, non_weak.dtype)\n",
" dtyp = np.result_type(dt1, dt2)\n",
" return dtyp"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "2c0b937a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dtype('float32')"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result_type(True, np.float32(2))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "cdecdf01",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dtype('float32')"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(True + np.float32(2)).dtype"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "c1f35436",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"weak + non-weak --> result_type / np_dtype\n",
"----------------------------------------------\n",
"True + bool --> bool / bool \n",
"True + uint8 --> uint8 / uint8 \n",
"True + int8 --> int8 / int8 \n",
"True + int32 --> int32 / int32 \n",
"True + int64 --> int64 / int64 \n",
"True + float32 --> float32 / float32 \n",
"True + float64 --> float64 / float64 \n",
"True + complex64 --> complex64 / complex64 \n",
"True + complex128 --> complex128 / complex128\n",
"1 + bool --> int64 / int64 \n",
"1 + uint8 --> uint8 / uint8 \n",
"1 + int8 --> int8 / int8 \n",
"1 + int32 --> int32 / int32 \n",
"1 + int64 --> int64 / int64 \n",
"1 + float32 --> float32 / float32 \n",
"1 + float64 --> float64 / float64 \n",
"1 + complex64 --> complex64 / complex64 \n",
"1 + complex128 --> complex128 / complex128\n",
"2.0 + bool --> float64 / float64 \n",
"2.0 + uint8 --> float64 / float64 \n",
"2.0 + int8 --> float64 / float64 \n",
"2.0 + int32 --> float64 / float64 \n",
"2.0 + int64 --> float64 / float64 \n",
"2.0 + float32 --> float32 / float32 \n",
"2.0 + float64 --> float64 / float64 \n",
"2.0 + complex64 --> complex64 / complex64 \n",
"2.0 + complex128 --> complex128 / complex128\n",
"3j + bool --> complex128 / complex128\n",
"3j + uint8 --> complex128 / complex128\n",
"3j + int8 --> complex128 / complex128\n",
"3j + int32 --> complex128 / complex128\n",
"3j + int64 --> complex128 / complex128\n",
"3j + float32 --> complex128 / complex64 \n",
"3j + float64 --> complex128 / complex128\n",
"3j + complex64 --> complex64 / complex64 \n",
"3j + complex128 --> complex128 / complex128\n"
]
}
],
"source": [
"weaks = [True, 1, 2.0, 3j]\n",
"non_weaks = [np.asarray(True),\n",
" np.uint8(1), np.int8(1), np.int32(1), np.int64(1),\n",
" np.float32(1), np.float64(1),\n",
" np.complex64(1), np.complex128(1)]\n",
"\n",
"import itertools\n",
"print(\"weak + non-weak --> result_type / np_dtype\")\n",
"print(\"----------------------------------------------\")\n",
"for w, n in itertools.product(weaks, non_weaks):\n",
" res = result_type(w, n)\n",
" res_np = np.add(w, n).dtype\n",
" print(f\"{str(w):<5} + {n.dtype.name:<10} --> {res.name:<10} / {res_np.name:<10}\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "7c2c2353",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.float64(0.1) == np.float32(0.1)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "433118f8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"0.1 == np.float32(0.1)"
]
},
{
"cell_type": "markdown",
"id": "dd09a80b",
"metadata": {},
"source": [
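"## Variant: use \"inexact\" as a single category (floats and complex)\n",
"\n",
"The version above disagrees with NumPy only for `3j + float32` (`complex128` vs `complex64`): the weak complex scalar is promoted to the default `complex128` instead of the complex counterpart of `float32`. Merging floats and complex into one category does not fix this. A weak complex scalar is now either ignored (when the other operand is a float) or mapped to the default inexact dtype `float64`, so the complex kind is lost: e.g. `3j + float32` gives `float32` and `3j + int8` gives `float64`, whereas NumPy gives `complex64` and `complex128` (see the output below)."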
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "0acacf8a",
"metadata": {},
"outputs": [],
"source": [
"categories = [('bool',),\n",
" ('uint8', 'int8', 'int16', 'int32', 'int64'), # plus other unsigned ints\n",
" ('float16', 'float32', 'float64', # plus longdouble\n",
" 'complex64', 'complex128')]\n",
"\n",
"\n",
"def category(dtyp):\n",
" for j, cat in enumerate(categories):\n",
" if dtyp in cat:\n",
" return j\n",
" raise ValueError(f\"{dtyp}?\")\n",
"\n",
"\n",
"dtype_for_cat = {0: \"bool\",\n",
" 1: \"int64\", # ignore windows int32\n",
" 2: \"float64\",\n",
" }\n",
"\n",
"\n",
"def pre_promote(weak_dtype, non_weak_dtype):\n",
" \"\"\"Return the dtype for the promoted arguments.\"\"\"\n",
" cat_weak = category(weak_dtype)\n",
" cat_non_weak = category(non_weak_dtype)\n",
" if cat_weak <= cat_non_weak:\n",
" # ignore the 'weak' dtype\n",
" return non_weak_dtype, non_weak_dtype\n",
" else:\n",
" # weak dtype -> default category\n",
" return dtype_for_cat[cat_weak], non_weak_dtype\n",
"\n",
" \n",
"def result_type(weak, non_weak):\n",
" \"\"\"The dtype of `binop(weak, non_weak)` under NEP 50\"\"\"\n",
" weak, non_weak = np.asarray(weak), np.asarray(non_weak)\n",
" dt1, dt2 = pre_promote(weak.dtype, non_weak.dtype)\n",
" dtyp = np.result_type(dt1, dt2)\n",
" return dtyp"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "cc80988a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"weak + non-weak --> result_type / np_dtype\n",
"----------------------------------------------\n",
"True + bool --> bool / bool \n",
"True + uint8 --> uint8 / uint8 \n",
"True + int8 --> int8 / int8 \n",
"True + int32 --> int32 / int32 \n",
"True + int64 --> int64 / int64 \n",
"True + float32 --> float32 / float32 \n",
"True + float64 --> float64 / float64 \n",
"True + complex64 --> complex64 / complex64 \n",
"True + complex128 --> complex128 / complex128\n",
"1 + bool --> int64 / int64 \n",
"1 + uint8 --> uint8 / uint8 \n",
"1 + int8 --> int8 / int8 \n",
"1 + int32 --> int32 / int32 \n",
"1 + int64 --> int64 / int64 \n",
"1 + float32 --> float32 / float32 \n",
"1 + float64 --> float64 / float64 \n",
"1 + complex64 --> complex64 / complex64 \n",
"1 + complex128 --> complex128 / complex128\n",
"2.0 + bool --> float64 / float64 \n",
"2.0 + uint8 --> float64 / float64 \n",
"2.0 + int8 --> float64 / float64 \n",
"2.0 + int32 --> float64 / float64 \n",
"2.0 + int64 --> float64 / float64 \n",
"2.0 + float32 --> float32 / float32 \n",
"2.0 + float64 --> float64 / float64 \n",
"2.0 + complex64 --> complex64 / complex64 \n",
"2.0 + complex128 --> complex128 / complex128\n",
"3j + bool --> float64 / complex128\n",
"3j + uint8 --> float64 / complex128\n",
"3j + int8 --> float64 / complex128\n",
"3j + int32 --> float64 / complex128\n",
"3j + int64 --> float64 / complex128\n",
"3j + float32 --> float32 / complex64 \n",
"3j + float64 --> float64 / complex128\n",
"3j + complex64 --> complex64 / complex64 \n",
"3j + complex128 --> complex128 / complex128\n"
]
}
],
"source": [
"weaks = [True, 1, 2.0, 3j]\n",
"non_weaks = [np.asarray(True),\n",
" np.uint8(1), np.int8(1), np.int32(1), np.int64(1),\n",
" np.float32(1), np.float64(1),\n",
" np.complex64(1), np.complex128(1)]\n",
"\n",
"import itertools\n",
"print(\"weak + non-weak --> result_type / np_dtype\")\n",
"print(\"----------------------------------------------\")\n",
"for w, n in itertools.product(weaks, non_weaks):\n",
" res = result_type(w, n)\n",
" res_np = np.add(w, n).dtype\n",
" print(f\"{str(w):<5} + {n.dtype.name:<10} --> {res.name:<10} / {res_np.name:<10}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9ca195c2",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment