Created
April 19, 2023 21:25
-
-
Save thomasjpfan/513115f8c6265b83c9fe69ec9f02f11a to your computer and use it in GitHub Desktop.
Torch Dynamo + numpy_pytorch_interop experiments
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "1dbefadb-5bd4-44b9-b1c8-6fd5047cc139", | |
"metadata": {}, | |
"source": [ | |
"# Experiments with Torch Dynamo & numpy_pytorch_interop\n", | |
"This notebook uses:\n", | |
"\n", | |
"- PyTorch built on [pytorch/pytorch#95849](https://github.com/pytorch/pytorch/pull/95849)\n", | |
"- `numpy_pytorch_interop` on [commit c4e4e43](https://github.com/Quansight-Labs/numpy_pytorch_interop/commit/c4e4e4369e864002a842ec377b8feeb758398a5c)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "c5e438f8-bdb1-40a7-a001-533f9d9b1876", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"import numpy as np\n", | |
"import torch._dynamo as dynamo" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"id": "6af8920c-f155-4a17-b8ea-27061798fd12", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"rng = np.random.default_rng(42)\n", | |
"X_np = rng.standard_normal(size=(100, 100))\n", | |
"X_torch = torch.as_tensor(X_np)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "abdff818-5755-4ff0-bfd7-23a10c7e9035", | |
"metadata": {}, | |
"source": [ | |
"## Expected behavior (With PyTorch)\n", | |
"This function uses pure PyTorch; Dynamo compiles it into a single graph with no graph breaks." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"id": "42e45d0b-f3f8-45e2-98d3-08a55da253c9", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Dynamo produced 1 graphs with 0 graph break and 2 ops\n", | |
" Break reasons: \n", | |
"\n", | |
"TorchDynamo compilation metrics:\n", | |
"Function, Runtimes (s)\n", | |
"_compile, 0.0065\n", | |
"OutputGraph.call_user_compiler, 0.0000\n" | |
] | |
} | |
], | |
"source": [ | |
"def mean_sum_torch(X):\n", | |
" X_mean = torch.mean(X, 1)\n", | |
" X_sum = torch.sum(X_mean)\n", | |
" return X_sum\n", | |
"\n", | |
"explaination = dynamo.explain(mean_sum_torch, X_torch)\n", | |
"print(explaination[-1])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "3958de69-ad7b-4f2b-b583-418b119f3c4d", | |
"metadata": {}, | |
"source": [ | |
"## Function with NumPy" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "04ff9cb1-6711-482d-9fc6-1f1d481724e9", | |
"metadata": {}, | |
"source": [ | |
"### NumPy Input" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "34c3bf3c-f710-40c6-a98f-a549017aca38", | |
"metadata": {}, | |
"source": [ | |
"Compiling with a NumPy input produces no graphs (expected behavior)." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"id": "c0620725-e6b3-4d40-8786-b70a07290104", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Dynamo produced 0 graphs with -1 graph break and 0 ops\n", | |
" Break reasons: \n", | |
"\n", | |
"TorchDynamo compilation metrics:\n", | |
"Function, Runtimes (s)\n", | |
"_compile, 0.0043\n" | |
] | |
} | |
], | |
"source": [ | |
"def mean_sum_np(X):\n", | |
" X_mean = np.mean(X, 1)\n", | |
" X_sum = np.sum(X_mean)\n", | |
" return X_sum\n", | |
"\n", | |
"explaination_np = dynamo.explain(mean_sum_np, X_np)\n", | |
"print(explaination_np[-1])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "e07d8bbe-6da0-443e-b3e4-307c24878f7f", | |
"metadata": {}, | |
"source": [ | |
"### PyTorch Input\n", | |
"I expected the following to work by going through the machinery in [pytorch/pytorch#95849](https://github.com/pytorch/pytorch/pull/95849), but it raises a `TypeError`:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"id": "59208e2b-1bbf-4f27-88fb-3d892acaf561", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"ename": "TypeError", | |
"evalue": "mean() received an invalid combination of arguments - got (dtype=NoneType, out=NoneType, axis=int, ), but expected one of:\n * (*, torch.dtype dtype)\n * (tuple of ints dim, bool keepdim, *, torch.dtype dtype)\n * (tuple of names dim, bool keepdim, *, torch.dtype dtype)\n", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", | |
"Cell \u001b[0;32mIn[27], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m _ \u001b[38;5;241m=\u001b[39m \u001b[43mdynamo\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexplain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmean_sum_np\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX_torch\u001b[49m\u001b[43m)\u001b[49m\n", | |
"File \u001b[0;32m~/mambaforge/envs/pytorch1/lib/python3.10/unittest/mock.py:1379\u001b[0m, in \u001b[0;36m_patch.decorate_callable.<locals>.patched\u001b[0;34m(*args, **keywargs)\u001b[0m\n\u001b[1;32m 1374\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(func)\n\u001b[1;32m 1375\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mpatched\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkeywargs):\n\u001b[1;32m 1376\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdecoration_helper(patched,\n\u001b[1;32m 1377\u001b[0m args,\n\u001b[1;32m 1378\u001b[0m keywargs) \u001b[38;5;28;01mas\u001b[39;00m (newargs, newkeywargs):\n\u001b[0;32m-> 1379\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mnewargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mnewkeywargs\u001b[49m\u001b[43m)\u001b[49m\n", | |
"File \u001b[0;32m~/Desktop/pytorch1/torch/_dynamo/eval_frame.py:564\u001b[0m, in \u001b[0;36mexplain\u001b[0;34m(f, *args, **kwargs)\u001b[0m\n\u001b[1;32m 558\u001b[0m opt_f \u001b[38;5;241m=\u001b[39m optimize(\n\u001b[1;32m 559\u001b[0m dynamo_graph_accumulating_compiler,\n\u001b[1;32m 560\u001b[0m nopython\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m 561\u001b[0m guard_export_fn\u001b[38;5;241m=\u001b[39mguard_export_print,\n\u001b[1;32m 562\u001b[0m )(f)\n\u001b[1;32m 563\u001b[0m \u001b[38;5;66;03m# TODO(voz): We may have instances of `f` that mutate inputs, we should track sideffects and reject.\u001b[39;00m\n\u001b[0;32m--> 564\u001b[0m \u001b[43mopt_f\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 566\u001b[0m graph_count \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlen\u001b[39m(graphs)\n\u001b[1;32m 568\u001b[0m \u001b[38;5;66;03m# For the explanation summary, dedupe reasons by the innermost stack frame and dedupe by it.\u001b[39;00m\n", | |
"File \u001b[0;32m~/Desktop/pytorch1/torch/_dynamo/eval_frame.py:253\u001b[0m, in \u001b[0;36m_TorchDynamoContext.__call__.<locals>._fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 251\u001b[0m dynamic_ctx\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__enter__\u001b[39m()\n\u001b[1;32m 252\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 253\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 254\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 255\u001b[0m set_eval_frame(prior)\n", | |
"Cell \u001b[0;32mIn[26], line 2\u001b[0m, in \u001b[0;36mmean_sum_np\u001b[0;34m(X)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmean_sum_np\u001b[39m(X):\n\u001b[0;32m----> 2\u001b[0m X_mean \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmean(X, \u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 3\u001b[0m X_sum \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39msum(X_mean)\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m X_sum\n", | |
"File \u001b[0;32m<__array_function__ internals>:200\u001b[0m, in \u001b[0;36mmean\u001b[0;34m(*args, **kwargs)\u001b[0m\n", | |
"File \u001b[0;32m~/mambaforge/envs/pytorch1/lib/python3.10/site-packages/numpy/core/fromnumeric.py:3462\u001b[0m, in \u001b[0;36mmean\u001b[0;34m(a, axis, dtype, out, keepdims, where)\u001b[0m\n\u001b[1;32m 3460\u001b[0m \u001b[38;5;28;01mpass\u001b[39;00m\n\u001b[1;32m 3461\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 3462\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mmean\u001b[49m\u001b[43m(\u001b[49m\u001b[43maxis\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maxis\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mout\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3464\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _methods\u001b[38;5;241m.\u001b[39m_mean(a, axis\u001b[38;5;241m=\u001b[39maxis, dtype\u001b[38;5;241m=\u001b[39mdtype,\n\u001b[1;32m 3465\u001b[0m out\u001b[38;5;241m=\u001b[39mout, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", | |
"\u001b[0;31mTypeError\u001b[0m: mean() received an invalid combination of arguments - got (dtype=NoneType, out=NoneType, axis=int, ), but expected one of:\n * (*, torch.dtype dtype)\n * (tuple of ints dim, bool keepdim, *, torch.dtype dtype)\n * (tuple of names dim, bool keepdim, *, torch.dtype dtype)\n" | |
] | |
} | |
], | |
"source": [ | |
"_ = dynamo.explain(mean_sum_np, X_torch)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "d0d46a77-e9d5-455a-8333-56b503ff0e1b", | |
"metadata": {}, | |
"source": [ | |
"At a glance, `np.mean` is running through `__array_function__` and not reaching the code in [pytorch/pytorch#95849](https://github.com/pytorch/pytorch/pull/95849)." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "acdf3f93-66ca-4943-a6c4-dae76411218a", | |
"metadata": {}, | |
"source": [ | |
"Compiling and then passing in a torch tensor gives the same error:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"id": "c365fc32-8341-4e67-a051-c6a2865ed3d9", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"mean_sum_np_compile = torch.compile(mean_sum_np)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"id": "1956bf4e-6f3b-4329-83f8-50aeafcf4b87", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"ename": "TypeError", | |
"evalue": "mean() received an invalid combination of arguments - got (dtype=NoneType, out=NoneType, axis=int, ), but expected one of:\n * (*, torch.dtype dtype)\n * (tuple of ints dim, bool keepdim, *, torch.dtype dtype)\n * (tuple of names dim, bool keepdim, *, torch.dtype dtype)\n", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", | |
"Cell \u001b[0;32mIn[29], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmean_sum_np_compile\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_torch\u001b[49m\u001b[43m)\u001b[49m\n", | |
"File \u001b[0;32m~/Desktop/pytorch1/torch/_dynamo/eval_frame.py:253\u001b[0m, in \u001b[0;36m_TorchDynamoContext.__call__.<locals>._fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 251\u001b[0m dynamic_ctx\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__enter__\u001b[39m()\n\u001b[1;32m 252\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 253\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 254\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 255\u001b[0m set_eval_frame(prior)\n", | |
"Cell \u001b[0;32mIn[26], line 2\u001b[0m, in \u001b[0;36mmean_sum_np\u001b[0;34m(X)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmean_sum_np\u001b[39m(X):\n\u001b[0;32m----> 2\u001b[0m X_mean \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmean(X, \u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 3\u001b[0m X_sum \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39msum(X_mean)\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m X_sum\n", | |
"File \u001b[0;32m<__array_function__ internals>:200\u001b[0m, in \u001b[0;36mmean\u001b[0;34m(*args, **kwargs)\u001b[0m\n", | |
"File \u001b[0;32m~/mambaforge/envs/pytorch1/lib/python3.10/site-packages/numpy/core/fromnumeric.py:3462\u001b[0m, in \u001b[0;36mmean\u001b[0;34m(a, axis, dtype, out, keepdims, where)\u001b[0m\n\u001b[1;32m 3460\u001b[0m \u001b[38;5;28;01mpass\u001b[39;00m\n\u001b[1;32m 3461\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 3462\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mmean\u001b[49m\u001b[43m(\u001b[49m\u001b[43maxis\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maxis\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mout\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3464\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _methods\u001b[38;5;241m.\u001b[39m_mean(a, axis\u001b[38;5;241m=\u001b[39maxis, dtype\u001b[38;5;241m=\u001b[39mdtype,\n\u001b[1;32m 3465\u001b[0m out\u001b[38;5;241m=\u001b[39mout, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", | |
"\u001b[0;31mTypeError\u001b[0m: mean() received an invalid combination of arguments - got (dtype=NoneType, out=NoneType, axis=int, ), but expected one of:\n * (*, torch.dtype dtype)\n * (tuple of ints dim, bool keepdim, *, torch.dtype dtype)\n * (tuple of names dim, bool keepdim, *, torch.dtype dtype)\n" | |
] | |
} | |
], | |
"source": [ | |
"mean_sum_np_compile(X_torch)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "1c9eafd3-f238-4e82-b57b-2845f5d0ddd5", | |
"metadata": {}, | |
"source": [ | |
"## Using torch_np\n", | |
"Here we rewrite the function using `torch_np`. If [pytorch/pytorch#95849](https://github.com/pytorch/pytorch/pull/95849) worked, I would expect at least this version to compile without graph breaks." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "290f59a6-e53e-4a2d-89b7-227f2b36efe8", | |
"metadata": {}, | |
"source": [ | |
"### NumPy input\n", | |
"Note that when running `explain` there is a graph break, which is different from the pure PyTorch version. I expected there to be no graph breaks from using `torch_np`." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 40, | |
"id": "00d7459f-d1ed-404c-bb4a-bf6c24d372c8", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Dynamo produced 2 graphs with 1 graph break and 0 ops\n", | |
" Break reasons: \n", | |
"\n", | |
"TorchDynamo compilation metrics:\n", | |
"Function, Runtimes (s)\n", | |
"_compile, 0.0054, 0.0045, 0.0086, 0.0024, 0.0010, 0.0023, 0.0189, 0.0014, 0.0038, 0.0095\n", | |
"OutputGraph.call_user_compiler, 0.0000, 0.0000\n" | |
] | |
} | |
], | |
"source": [ | |
"import torch_np\n", | |
"\n", | |
"def mean_sum_torch_np(X):\n", | |
" X_mean = torch_np.mean(X, 1)\n", | |
" X_sum = torch_np.sum(X_mean)\n", | |
" return X_mean\n", | |
"\n", | |
"explaination_np_input = dynamo.explain(mean_sum_torch_np, X_np)\n", | |
"print(explaination_np_input[-1])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "46110f76-236f-442a-8bea-0bdc7e23d124", | |
"metadata": {}, | |
"source": [ | |
"### PyTorch input" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 34, | |
"id": "7308daca-5c74-4904-9b7e-1b530f6ea22d", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Dynamo produced 2 graphs with 1 graph break and 0 ops\n", | |
" Break reasons: \n", | |
"\n", | |
"TorchDynamo compilation metrics:\n", | |
"Function, Runtimes (s)\n", | |
"_compile, 0.0062, 0.0056, 0.0122, 0.0103, 0.0102, 0.0030, 0.0010, 0.0023, 0.0187, 0.0014, 0.0034, 0.0054\n", | |
"OutputGraph.call_user_compiler, 0.0000, 0.0000\n" | |
] | |
} | |
], | |
"source": [ | |
"explaination_torch_input = dynamo.explain(mean_sum_torch_np, X_torch)\n", | |
"print(explaination_torch_input[-1])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "93fef875-b707-4c83-af52-ad91a8606f0d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "2a6c6686-8c56-4395-8373-fc9191d52b94", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "6ce73184-fd76-48fa-9b1c-561665a35fc4", | |
"metadata": {}, | |
"source": [ | |
"## Try using a NumPy function that does not use `__array_function__`" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "7581f42c-68d2-4c54-9d30-25a5f544df3e", | |
"metadata": {}, | |
"source": [ | |
"### Pure PyTorch version" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 83, | |
"id": "29b599b8-10e0-4d4c-a5a2-1b67efcc18c9", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"import torch.linalg\n", | |
"\n", | |
"def det_sum_torch(X, Y):\n", | |
" X_det = torch.linalg.det(X)\n", | |
" Y_det = torch.linalg.det(Y)\n", | |
" return X_det + Y_det" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 85, | |
"id": "6b1a97b2-7e9a-4b1f-ab45-7e0012c0776f", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Dynamo produced 1 graphs with 0 graph break and 3 ops\n", | |
" Break reasons: \n", | |
"\n", | |
"TorchDynamo compilation metrics:\n", | |
"Function, Runtimes (s)\n", | |
"_compile, 0.0089\n", | |
"OutputGraph.call_user_compiler, 0.0000\n" | |
] | |
} | |
], | |
"source": [ | |
"explain_det_sum_pure_torch = dynamo.explain(det_sum_torch, X_torch, X_torch)\n", | |
"print(explain_det_sum_pure_torch[-1])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "b5e7795e-ef14-4683-adef-37d790ec923a", | |
"metadata": {}, | |
"source": [ | |
"### NumPy version" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 86, | |
"id": "614bc493-fd43-40e0-b8d6-cf068842e263", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"import numpy.linalg\n", | |
"\n", | |
"def det_sum_np(X, Y):\n", | |
" X_det = numpy.linalg.det(X)\n", | |
" Y_det = numpy.linalg.det(Y)\n", | |
" return X_det + Y_det" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "02d07c46-ba78-4540-804d-38fb80dbace6", | |
"metadata": {}, | |
"source": [ | |
"### Torch input" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "8efc1996-ef38-4010-a6bf-7d8581216dbb", | |
"metadata": {}, | |
"source": [ | |
"With Torch input, Dynamo captures no graphs (see the output below)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 88, | |
"id": "9e86423a-8a45-4019-bb49-7feeaaae7940", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Dynamo produced 0 graphs with -1 graph break and 0 ops\n", | |
" Break reasons: \n", | |
"\n", | |
"TorchDynamo compilation metrics:\n", | |
"Function, Runtimes (s)\n", | |
"_compile, 0.0051, 0.0044\n" | |
] | |
} | |
], | |
"source": [ | |
"explain_det_sum_torch = dynamo.explain(det_sum_np, X_torch, X_torch)\n", | |
"print(explain_det_sum_torch[-1])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "42542a86-eb26-4b8c-be18-ce5b0a960860", | |
"metadata": {}, | |
"source": [ | |
"### NumPy Input" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "7c40ad85-1618-4a05-bb46-c11a299d6d73", | |
"metadata": {}, | |
"source": [ | |
"With NumPy input, Dynamo captures no graphs (see the output below)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 89, | |
"id": "2bec5342-4364-4359-b305-b83eff335fac", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Dynamo produced 0 graphs with -1 graph break and 0 ops\n", | |
" Break reasons: \n", | |
"\n", | |
"TorchDynamo compilation metrics:\n", | |
"Function, Runtimes (s)\n", | |
"_compile, 0.0031\n" | |
] | |
} | |
], | |
"source": [ | |
"explain_det_sum_numpy = dynamo.explain(det_sum_np, X_np, X_np)\n", | |
"print(explain_det_sum_numpy[-1])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "f5f9651e-09e8-4d57-ae81-5d855dd972c5", | |
"metadata": {}, | |
"source": [ | |
"### Using `torch_np` directly" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "0e784a88-299f-4c9c-be80-5dcc718c4755", | |
"metadata": {}, | |
"source": [ | |
"This raises an error:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 90, | |
"id": "2b84be48-3a8e-4b7b-a260-7d2a18a1bb66", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"from torch_np import linalg\n", | |
"\n", | |
"def det_sum_torch_np(X, Y):\n", | |
" X_det = linalg.det(X)\n", | |
" Y_det = linalg.det(Y)\n", | |
" return X_det + Y_det" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 91, | |
"id": "2d4c2735-0ea1-49c1-ac61-a0cb57e9e688", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"ename": "InternalTorchDynamoError", | |
"evalue": "", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", | |
"File \u001b[0;32m~/Desktop/pytorch1/torch/_dynamo/convert_frame.py:431\u001b[0m, in \u001b[0;36m_compile\u001b[0;34m(code, globals, locals, builtins, compiler_fn, one_graph, export, export_constraints, hooks, frame, frame_state)\u001b[0m\n\u001b[1;32m 430\u001b[0m CleanupManager\u001b[38;5;241m.\u001b[39minstance[out_code] \u001b[38;5;241m=\u001b[39m output\u001b[38;5;241m.\u001b[39mcleanups\n\u001b[0;32m--> 431\u001b[0m check_fn \u001b[38;5;241m=\u001b[39m \u001b[43mCheckFunctionManager\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 432\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 433\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mlocals\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 434\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mglobals\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 435\u001b[0m \u001b[43m \u001b[49m\u001b[43mhooks\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mguard_fail_fn\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mhooks\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 436\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 438\u001b[0m guarded_code \u001b[38;5;241m=\u001b[39m GuardedCode(out_code, check_fn\u001b[38;5;241m.\u001b[39mcheck_fn)\n", | |
"File \u001b[0;32m~/Desktop/pytorch1/torch/_dynamo/guards.py:688\u001b[0m, in \u001b[0;36mCheckFunctionManager.__init__\u001b[0;34m(self, output_graph, f_locals, f_globals, guard_fail_fn)\u001b[0m\n\u001b[1;32m 687\u001b[0m \u001b[38;5;28;01mcontinue\u001b[39;00m\n\u001b[0;32m--> 688\u001b[0m \u001b[43mguard\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcreate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlocal_builder\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mglobal_builder\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 689\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcheck_fn \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcompile_check_fn(\n\u001b[1;32m 690\u001b[0m local_builder, global_builder, guards, guard_fail_fn\n\u001b[1;32m 691\u001b[0m )\n", | |
"File \u001b[0;32m~/Desktop/pytorch1/torch/_guards.py:184\u001b[0m, in \u001b[0;36mGuard.create\u001b[0;34m(self, local_builder, global_builder)\u001b[0m\n\u001b[1;32m 183\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcreate\u001b[39m(\u001b[38;5;28mself\u001b[39m, local_builder: GuardBuilderBase, global_builder: GuardBuilderBase):\n\u001b[0;32m--> 184\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcreate_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msource\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mselect\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlocal_builder\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mglobal_builder\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m)\u001b[49m\n", | |
"File \u001b[0;32m~/Desktop/pytorch1/torch/_dynamo/guards.py:189\u001b[0m, in \u001b[0;36mGuardBuilder.TYPE_MATCH\u001b[0;34m(self, guard)\u001b[0m\n\u001b[1;32m 187\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mTYPE_MATCH\u001b[39m(\u001b[38;5;28mself\u001b[39m, guard: Guard):\n\u001b[1;32m 188\u001b[0m \u001b[38;5;66;03m# ___check_type_id is same as `id(type(x)) == y`\u001b[39;00m\n\u001b[0;32m--> 189\u001b[0m t \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[43mguard\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 190\u001b[0m obj_id \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mid_ref(t)\n", | |
"File \u001b[0;32m~/Desktop/pytorch1/torch/_dynamo/guards.py:165\u001b[0m, in \u001b[0;36mGuardBuilder.get\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 164\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget\u001b[39m(\u001b[38;5;28mself\u001b[39m, name: \u001b[38;5;28mstr\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[0;32m--> 165\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43meval\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscope\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mCLOSURE_VARS\u001b[49m\u001b[43m)\u001b[49m\n", | |
"File \u001b[0;32m<string>:1\u001b[0m\n", | |
"\u001b[0;31mNameError\u001b[0m: name 'torch' is not defined\n\n\nYou can suppress this exception and fall back to eager by setting:\n torch._dynamo.config.suppress_errors = True\n", | |
"\nThe above exception was the direct cause of the following exception:\n", | |
"\u001b[0;31mInternalTorchDynamoError\u001b[0m Traceback (most recent call last)", | |
"Cell \u001b[0;32mIn[91], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m explain_det_sum_torch_np \u001b[38;5;241m=\u001b[39m \u001b[43mdynamo\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexplain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdet_sum_torch_np\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX_torch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX_torch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28mprint\u001b[39m(explain_det_sum_torch_np[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m])\n", | |
"File \u001b[0;32m~/mambaforge/envs/pytorch1/lib/python3.10/unittest/mock.py:1379\u001b[0m, in \u001b[0;36m_patch.decorate_callable.<locals>.patched\u001b[0;34m(*args, **keywargs)\u001b[0m\n\u001b[1;32m 1374\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(func)\n\u001b[1;32m 1375\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mpatched\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkeywargs):\n\u001b[1;32m 1376\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdecoration_helper(patched,\n\u001b[1;32m 1377\u001b[0m args,\n\u001b[1;32m 1378\u001b[0m keywargs) \u001b[38;5;28;01mas\u001b[39;00m (newargs, newkeywargs):\n\u001b[0;32m-> 1379\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mnewargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mnewkeywargs\u001b[49m\u001b[43m)\u001b[49m\n", | |
"File \u001b[0;32m~/Desktop/pytorch1/torch/_dynamo/eval_frame.py:564\u001b[0m, in \u001b[0;36mexplain\u001b[0;34m(f, *args, **kwargs)\u001b[0m\n\u001b[1;32m 558\u001b[0m opt_f \u001b[38;5;241m=\u001b[39m optimize(\n\u001b[1;32m 559\u001b[0m dynamo_graph_accumulating_compiler,\n\u001b[1;32m 560\u001b[0m nopython\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m 561\u001b[0m guard_export_fn\u001b[38;5;241m=\u001b[39mguard_export_print,\n\u001b[1;32m 562\u001b[0m )(f)\n\u001b[1;32m 563\u001b[0m \u001b[38;5;66;03m# TODO(voz): We may have instances of `f` that mutate inputs, we should track sideffects and reject.\u001b[39;00m\n\u001b[0;32m--> 564\u001b[0m \u001b[43mopt_f\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 566\u001b[0m graph_count \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlen\u001b[39m(graphs)\n\u001b[1;32m 568\u001b[0m \u001b[38;5;66;03m# For the explanation summary, dedupe reasons by the innermost stack frame and dedupe by it.\u001b[39;00m\n", | |
"File \u001b[0;32m~/Desktop/pytorch1/torch/_dynamo/eval_frame.py:253\u001b[0m, in \u001b[0;36m_TorchDynamoContext.__call__.<locals>._fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 251\u001b[0m dynamic_ctx\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__enter__\u001b[39m()\n\u001b[1;32m 252\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 253\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 254\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 255\u001b[0m set_eval_frame(prior)\n", | |
"Cell \u001b[0;32mIn[90], line 4\u001b[0m, in \u001b[0;36mdet_sum_torch_np\u001b[0;34m(X, Y)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdet_sum_torch_np\u001b[39m(X, Y):\n\u001b[0;32m----> 4\u001b[0m X_det \u001b[38;5;241m=\u001b[39m linalg\u001b[38;5;241m.\u001b[39mdet(X)\n\u001b[1;32m 5\u001b[0m Y_det \u001b[38;5;241m=\u001b[39m linalg\u001b[38;5;241m.\u001b[39mdet(Y)\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m X_det \u001b[38;5;241m+\u001b[39m Y_det\n", | |
"Cell \u001b[0;32mIn[90], line 5\u001b[0m, in \u001b[0;36m<resume in det_sum_torch_np>\u001b[0;34m(___stack0, Y)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdet_sum_torch_np\u001b[39m(X, Y):\n\u001b[1;32m 4\u001b[0m X_det \u001b[38;5;241m=\u001b[39m linalg\u001b[38;5;241m.\u001b[39mdet(X)\n\u001b[0;32m----> 5\u001b[0m Y_det \u001b[38;5;241m=\u001b[39m linalg\u001b[38;5;241m.\u001b[39mdet(Y)\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m X_det \u001b[38;5;241m+\u001b[39m Y_det\n", | |
"Cell \u001b[0;32mIn[90], line 6\u001b[0m, in \u001b[0;36m<resume in det_sum_torch_np>\u001b[0;34m(___stack0, X_det)\u001b[0m\n\u001b[1;32m 4\u001b[0m X_det \u001b[38;5;241m=\u001b[39m linalg\u001b[38;5;241m.\u001b[39mdet(X)\n\u001b[1;32m 5\u001b[0m Y_det \u001b[38;5;241m=\u001b[39m linalg\u001b[38;5;241m.\u001b[39mdet(Y)\n\u001b[0;32m----> 6\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m X_det \u001b[38;5;241m+\u001b[39m Y_det\n", | |
"File \u001b[0;32m~/mambaforge/envs/pytorch1/lib/python3.10/site-packages/torch_np/_normalizations.py:198\u001b[0m, in \u001b[0;36mnormalizer.<locals>.normalizer_inner.<locals>.wrapped\u001b[0;34m(*args, **kwds)\u001b[0m\n\u001b[1;32m 186\u001b[0m args \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 187\u001b[0m \u001b[38;5;28mtuple\u001b[39m(\n\u001b[1;32m 188\u001b[0m maybe_normalize(arg, parm)\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 191\u001b[0m \u001b[38;5;241m+\u001b[39m args[\u001b[38;5;28mlen\u001b[39m(params\u001b[38;5;241m.\u001b[39mvalues()) :]\n\u001b[1;32m 192\u001b[0m )\n\u001b[1;32m 194\u001b[0m kwds \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 195\u001b[0m name: maybe_normalize(arg, params[name]) \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m params \u001b[38;5;28;01melse\u001b[39;00m arg\n\u001b[1;32m 196\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m name, arg \u001b[38;5;129;01min\u001b[39;00m kwds\u001b[38;5;241m.\u001b[39mitems()\n\u001b[1;32m 197\u001b[0m }\n\u001b[0;32m--> 198\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwds\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mout\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m params:\n\u001b[1;32m 201\u001b[0m out \u001b[38;5;241m=\u001b[39m sig\u001b[38;5;241m.\u001b[39mbind(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwds)\u001b[38;5;241m.\u001b[39marguments\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mout\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", | |
"File \u001b[0;32m~/Desktop/pytorch1/torch/_dynamo/eval_frame.py:402\u001b[0m, in \u001b[0;36mcatch_errors_wrapper.<locals>.catch_errors\u001b[0;34m(frame, cache_size, frame_state)\u001b[0m\n\u001b[1;32m 399\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m hijacked_callback(frame, cache_size, hooks, frame_state)\n\u001b[1;32m 401\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m compile_lock:\n\u001b[0;32m--> 402\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcallback\u001b[49m\u001b[43m(\u001b[49m\u001b[43mframe\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcache_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhooks\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mframe_state\u001b[49m\u001b[43m)\u001b[49m\n", | |
"File \u001b[0;32m~/Desktop/pytorch1/torch/_dynamo/convert_frame.py:480\u001b[0m, in \u001b[0;36mconvert_frame.<locals>._convert_frame\u001b[0;34m(frame, cache_size, hooks, frame_state)\u001b[0m\n\u001b[1;32m 478\u001b[0m counters[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mframes\u001b[39m\u001b[38;5;124m\"\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtotal\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 479\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 480\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43minner_convert\u001b[49m\u001b[43m(\u001b[49m\u001b[43mframe\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcache_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhooks\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mframe_state\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 481\u001b[0m counters[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mframes\u001b[39m\u001b[38;5;124m\"\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mok\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 482\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m result\n", | |
"File \u001b[0;32m~/Desktop/pytorch1/torch/_dynamo/convert_frame.py:117\u001b[0m, in \u001b[0;36mwrap_convert_context.<locals>._fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 115\u001b[0m cleanup \u001b[38;5;241m=\u001b[39m setup_compile_debug()\n\u001b[1;32m 116\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 117\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 118\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 119\u001b[0m cleanup\u001b[38;5;241m.\u001b[39mclose()\n", | |
"File \u001b[0;32m~/Desktop/pytorch1/torch/_dynamo/convert_frame.py:318\u001b[0m, in \u001b[0;36mconvert_frame_assert.<locals>._convert_frame_assert\u001b[0;34m(frame, cache_size, hooks, frame_state)\u001b[0m\n\u001b[1;32m 303\u001b[0m initial_deterministic_algorithms_state \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 304\u001b[0m torch\u001b[38;5;241m.\u001b[39mare_deterministic_algorithms_enabled()\n\u001b[1;32m 305\u001b[0m )\n\u001b[1;32m 307\u001b[0m signpost_event(\n\u001b[1;32m 308\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdynamo\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 309\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_convert_frame_assert._compile\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 315\u001b[0m },\n\u001b[1;32m 316\u001b[0m )\n\u001b[0;32m--> 318\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_compile\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 319\u001b[0m \u001b[43m \u001b[49m\u001b[43mframe\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mf_code\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 320\u001b[0m \u001b[43m \u001b[49m\u001b[43mframe\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mf_globals\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 321\u001b[0m \u001b[43m \u001b[49m\u001b[43mframe\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mf_locals\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 322\u001b[0m \u001b[43m \u001b[49m\u001b[43mframe\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mf_builtins\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 323\u001b[0m \u001b[43m \u001b[49m\u001b[43mcompiler_fn\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 324\u001b[0m \u001b[43m \u001b[49m\u001b[43mone_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 325\u001b[0m \u001b[43m \u001b[49m\u001b[43mexport\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 326\u001b[0m \u001b[43m \u001b[49m\u001b[43mexport_constraints\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 327\u001b[0m 
\u001b[43m \u001b[49m\u001b[43mhooks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 328\u001b[0m \u001b[43m \u001b[49m\u001b[43mframe\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 329\u001b[0m \u001b[43m \u001b[49m\u001b[43mframe_state\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mframe_state\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 330\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", | |
"File \u001b[0;32m~/Desktop/pytorch1/torch/_dynamo/utils.py:177\u001b[0m, in \u001b[0;36mdynamo_timed.<locals>.dynamo_timed_inner.<locals>.time_wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 175\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mrecord_function(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mkey\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m (dynamo_timed)\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 176\u001b[0m t0 \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[0;32m--> 177\u001b[0m r \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 178\u001b[0m time_spent \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime() \u001b[38;5;241m-\u001b[39m t0\n\u001b[1;32m 179\u001b[0m \u001b[38;5;66;03m# print(f\"Dynamo timer: key={key}, latency={latency:.2f} sec\")\u001b[39;00m\n", | |
"File \u001b[0;32m~/Desktop/pytorch1/torch/_dynamo/convert_frame.py:468\u001b[0m, in \u001b[0;36m_compile\u001b[0;34m(code, globals, locals, builtins, compiler_fn, one_graph, export, export_constraints, hooks, frame, frame_state)\u001b[0m\n\u001b[1;32m 466\u001b[0m exception_handler(e, code, frame)\n\u001b[1;32m 467\u001b[0m \u001b[38;5;66;03m# TODO: Why??? Why not raise the original exception as is\u001b[39;00m\n\u001b[0;32m--> 468\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m InternalTorchDynamoError() \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01me\u001b[39;00m\n", | |
"\u001b[0;31mInternalTorchDynamoError\u001b[0m: " | |
] | |
} | |
], | |
"source": [ | |
"explain_det_sum_torch_np = dynamo.explain(det_sum_torch_np, X_torch, X_torch)\n", | |
"print(explain_det_sum_torch_np[-1])" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "pytorch1 (python3)", | |
"language": "python", | |
"name": "conda-env-pytorch1-py" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.9" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment