Skip to content

Instantly share code, notes, and snippets.

@thomasjpfan
Created April 19, 2023 21:25
Show Gist options
  • Save thomasjpfan/513115f8c6265b83c9fe69ec9f02f11a to your computer and use it in GitHub Desktop.
Save thomasjpfan/513115f8c6265b83c9fe69ec9f02f11a to your computer and use it in GitHub Desktop.
Torch Dynamo + numpy_pytorch_interop experiments
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "1dbefadb-5bd4-44b9-b1c8-6fd5047cc139",
"metadata": {},
"source": [
"# Experiments with Torch Dynamo & numpy_pytorch_interop\n",
"This notebook uses:\n",
"\n",
"- PyTorch built on [pytorch/pytorch#95849](https://github.com/pytorch/pytorch/pull/95849)\n",
"- `numpy_pytorch_interop` on [commit c4e4e43](https://github.com/Quansight-Labs/numpy_pytorch_interop/commit/c4e4e4369e864002a842ec377b8feeb758398a5c)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "c5e438f8-bdb1-40a7-a001-533f9d9b1876",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import torch\n",
"import numpy as np\n",
"import torch._dynamo as dynamo"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "6af8920c-f155-4a17-b8ea-27061798fd12",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"rng = np.random.default_rng(42)\n",
"X_np = rng.standard_normal(size=(100, 100))\n",
"X_torch = torch.as_tensor(X_np)"
]
},
{
"cell_type": "markdown",
"id": "abdff818-5755-4ff0-bfd7-23a10c7e9035",
"metadata": {},
"source": [
"## Expected behavior (With PyTorch)\n",
"This function uses pure PyTorch; Dynamo compiles it into a single graph with no graph breaks."
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "42e45d0b-f3f8-45e2-98d3-08a55da253c9",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dynamo produced 1 graphs with 0 graph break and 2 ops\n",
" Break reasons: \n",
"\n",
"TorchDynamo compilation metrics:\n",
"Function, Runtimes (s)\n",
"_compile, 0.0065\n",
"OutputGraph.call_user_compiler, 0.0000\n"
]
}
],
"source": [
"def mean_sum_torch(X):\n",
" X_mean = torch.mean(X, 1)\n",
" X_sum = torch.sum(X_mean)\n",
" return X_sum\n",
"\n",
"explaination = dynamo.explain(mean_sum_torch, X_torch)\n",
"print(explaination[-1])"
]
},
{
"cell_type": "markdown",
"id": "3958de69-ad7b-4f2b-b583-418b119f3c4d",
"metadata": {},
"source": [
"## Function with NumPy"
]
},
{
"cell_type": "markdown",
"id": "04ff9cb1-6711-482d-9fc6-1f1d481724e9",
"metadata": {},
"source": [
"### NumPy Input"
]
},
{
"cell_type": "markdown",
"id": "34c3bf3c-f710-40c6-a98f-a549017aca38",
"metadata": {},
"source": [
"Compiling with a NumPy input produces no graphs (expected behavior)."
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "c0620725-e6b3-4d40-8786-b70a07290104",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dynamo produced 0 graphs with -1 graph break and 0 ops\n",
" Break reasons: \n",
"\n",
"TorchDynamo compilation metrics:\n",
"Function, Runtimes (s)\n",
"_compile, 0.0043\n"
]
}
],
"source": [
"def mean_sum_np(X):\n",
" X_mean = np.mean(X, 1)\n",
" X_sum = np.sum(X_mean)\n",
" return X_sum\n",
"\n",
"explaination_np = dynamo.explain(mean_sum_np, X_np)\n",
"print(explaination_np[-1])"
]
},
{
"cell_type": "markdown",
"id": "e07d8bbe-6da0-443e-b3e4-307c24878f7f",
"metadata": {},
"source": [
"### PyTorch Input\n",
"I expected the following to work and go through the machinery in [pytorch/pytorch#95849](https://github.com/pytorch/pytorch/pull/95849), but it raises a `TypeError`:"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "59208e2b-1bbf-4f27-88fb-3d892acaf561",
"metadata": {
"tags": []
},
"outputs": [
{
"ename": "TypeError",
"evalue": "mean() received an invalid combination of arguments - got (dtype=NoneType, out=NoneType, axis=int, ), but expected one of:\n * (*, torch.dtype dtype)\n * (tuple of ints dim, bool keepdim, *, torch.dtype dtype)\n * (tuple of names dim, bool keepdim, *, torch.dtype dtype)\n",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[27], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m _ \u001b[38;5;241m=\u001b[39m \u001b[43mdynamo\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexplain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmean_sum_np\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX_torch\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/mambaforge/envs/pytorch1/lib/python3.10/unittest/mock.py:1379\u001b[0m, in \u001b[0;36m_patch.decorate_callable.<locals>.patched\u001b[0;34m(*args, **keywargs)\u001b[0m\n\u001b[1;32m 1374\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(func)\n\u001b[1;32m 1375\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mpatched\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkeywargs):\n\u001b[1;32m 1376\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdecoration_helper(patched,\n\u001b[1;32m 1377\u001b[0m args,\n\u001b[1;32m 1378\u001b[0m keywargs) \u001b[38;5;28;01mas\u001b[39;00m (newargs, newkeywargs):\n\u001b[0;32m-> 1379\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mnewargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mnewkeywargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Desktop/pytorch1/torch/_dynamo/eval_frame.py:564\u001b[0m, in \u001b[0;36mexplain\u001b[0;34m(f, *args, **kwargs)\u001b[0m\n\u001b[1;32m 558\u001b[0m opt_f \u001b[38;5;241m=\u001b[39m optimize(\n\u001b[1;32m 559\u001b[0m dynamo_graph_accumulating_compiler,\n\u001b[1;32m 560\u001b[0m nopython\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m 561\u001b[0m guard_export_fn\u001b[38;5;241m=\u001b[39mguard_export_print,\n\u001b[1;32m 562\u001b[0m )(f)\n\u001b[1;32m 563\u001b[0m \u001b[38;5;66;03m# TODO(voz): We may have instances of `f` that mutate inputs, we should track sideffects and reject.\u001b[39;00m\n\u001b[0;32m--> 564\u001b[0m \u001b[43mopt_f\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 566\u001b[0m graph_count \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlen\u001b[39m(graphs)\n\u001b[1;32m 568\u001b[0m \u001b[38;5;66;03m# For the explanation summary, dedupe reasons by the innermost stack frame and dedupe by it.\u001b[39;00m\n",
"File \u001b[0;32m~/Desktop/pytorch1/torch/_dynamo/eval_frame.py:253\u001b[0m, in \u001b[0;36m_TorchDynamoContext.__call__.<locals>._fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 251\u001b[0m dynamic_ctx\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__enter__\u001b[39m()\n\u001b[1;32m 252\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 253\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 254\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 255\u001b[0m set_eval_frame(prior)\n",
"Cell \u001b[0;32mIn[26], line 2\u001b[0m, in \u001b[0;36mmean_sum_np\u001b[0;34m(X)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmean_sum_np\u001b[39m(X):\n\u001b[0;32m----> 2\u001b[0m X_mean \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmean(X, \u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 3\u001b[0m X_sum \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39msum(X_mean)\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m X_sum\n",
"File \u001b[0;32m<__array_function__ internals>:200\u001b[0m, in \u001b[0;36mmean\u001b[0;34m(*args, **kwargs)\u001b[0m\n",
"File \u001b[0;32m~/mambaforge/envs/pytorch1/lib/python3.10/site-packages/numpy/core/fromnumeric.py:3462\u001b[0m, in \u001b[0;36mmean\u001b[0;34m(a, axis, dtype, out, keepdims, where)\u001b[0m\n\u001b[1;32m 3460\u001b[0m \u001b[38;5;28;01mpass\u001b[39;00m\n\u001b[1;32m 3461\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 3462\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mmean\u001b[49m\u001b[43m(\u001b[49m\u001b[43maxis\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maxis\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mout\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3464\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _methods\u001b[38;5;241m.\u001b[39m_mean(a, axis\u001b[38;5;241m=\u001b[39maxis, dtype\u001b[38;5;241m=\u001b[39mdtype,\n\u001b[1;32m 3465\u001b[0m out\u001b[38;5;241m=\u001b[39mout, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
"\u001b[0;31mTypeError\u001b[0m: mean() received an invalid combination of arguments - got (dtype=NoneType, out=NoneType, axis=int, ), but expected one of:\n * (*, torch.dtype dtype)\n * (tuple of ints dim, bool keepdim, *, torch.dtype dtype)\n * (tuple of names dim, bool keepdim, *, torch.dtype dtype)\n"
]
}
],
"source": [
"_ = dynamo.explain(mean_sum_np, X_torch)"
]
},
{
"cell_type": "markdown",
"id": "d0d46a77-e9d5-455a-8333-56b503ff0e1b",
"metadata": {},
"source": [
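"At a glance, `np.mean` is not reaching the code in [pytorch/pytorch#95849](https://github.com/pytorch/pytorch/pull/95849): the function body runs eagerly, `np.mean` goes through NumPy's `__array_function__` dispatch wrapper, and NumPy then calls the tensor's own `mean` method with NumPy-style `axis=`/`dtype=`/`out=` keywords (`fromnumeric.py:3462` in the traceback). `torch.Tensor.mean` does not accept those keywords, which is where the `TypeError` comes from."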
]
},
{
"cell_type": "markdown",
"id": "acdf3f93-66ca-4943-a6c4-dae76411218a",
"metadata": {},
"source": [
"Compiling and then passing in a torch tensor gives the same error:"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "c365fc32-8341-4e67-a051-c6a2865ed3d9",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"mean_sum_np_compile = torch.compile(mean_sum_np)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "1956bf4e-6f3b-4329-83f8-50aeafcf4b87",
"metadata": {
"tags": []
},
"outputs": [
{
"ename": "TypeError",
"evalue": "mean() received an invalid combination of arguments - got (dtype=NoneType, out=NoneType, axis=int, ), but expected one of:\n * (*, torch.dtype dtype)\n * (tuple of ints dim, bool keepdim, *, torch.dtype dtype)\n * (tuple of names dim, bool keepdim, *, torch.dtype dtype)\n",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[29], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmean_sum_np_compile\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_torch\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Desktop/pytorch1/torch/_dynamo/eval_frame.py:253\u001b[0m, in \u001b[0;36m_TorchDynamoContext.__call__.<locals>._fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 251\u001b[0m dynamic_ctx\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__enter__\u001b[39m()\n\u001b[1;32m 252\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 253\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 254\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 255\u001b[0m set_eval_frame(prior)\n",
"Cell \u001b[0;32mIn[26], line 2\u001b[0m, in \u001b[0;36mmean_sum_np\u001b[0;34m(X)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmean_sum_np\u001b[39m(X):\n\u001b[0;32m----> 2\u001b[0m X_mean \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmean(X, \u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 3\u001b[0m X_sum \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39msum(X_mean)\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m X_sum\n",
"File \u001b[0;32m<__array_function__ internals>:200\u001b[0m, in \u001b[0;36mmean\u001b[0;34m(*args, **kwargs)\u001b[0m\n",
"File \u001b[0;32m~/mambaforge/envs/pytorch1/lib/python3.10/site-packages/numpy/core/fromnumeric.py:3462\u001b[0m, in \u001b[0;36mmean\u001b[0;34m(a, axis, dtype, out, keepdims, where)\u001b[0m\n\u001b[1;32m 3460\u001b[0m \u001b[38;5;28;01mpass\u001b[39;00m\n\u001b[1;32m 3461\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 3462\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mmean\u001b[49m\u001b[43m(\u001b[49m\u001b[43maxis\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maxis\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mout\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3464\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _methods\u001b[38;5;241m.\u001b[39m_mean(a, axis\u001b[38;5;241m=\u001b[39maxis, dtype\u001b[38;5;241m=\u001b[39mdtype,\n\u001b[1;32m 3465\u001b[0m out\u001b[38;5;241m=\u001b[39mout, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
"\u001b[0;31mTypeError\u001b[0m: mean() received an invalid combination of arguments - got (dtype=NoneType, out=NoneType, axis=int, ), but expected one of:\n * (*, torch.dtype dtype)\n * (tuple of ints dim, bool keepdim, *, torch.dtype dtype)\n * (tuple of names dim, bool keepdim, *, torch.dtype dtype)\n"
]
}
],
"source": [
"mean_sum_np_compile(X_torch)"
]
},
{
"cell_type": "markdown",
"id": "1c9eafd3-f238-4e82-b57b-2845f5d0ddd5",
"metadata": {},
"source": [
"## Using torch_np\n",
"Here we rewrite the function using `torch_np` directly. If [pytorch/pytorch#95849](https://github.com/pytorch/pytorch/pull/95849) worked as intended, I would expect at least this version to compile with no graph breaks."
]
},
{
"cell_type": "markdown",
"id": "290f59a6-e53e-4a2d-89b7-227f2b36efe8",
"metadata": {},
"source": [
"### NumPy input\n",
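"Note that `explain` reports a graph break here (2 graphs and 0 ops), unlike the pure PyTorch version (1 graph and 2 ops). I expected there to be no graph breaks from using `torch_np`. Also note that `mean_sum_torch_np` below returns `X_mean` rather than `X_sum`, so `X_sum` is unused and the function is not an exact analogue of `mean_sum_torch`."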
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "00d7459f-d1ed-404c-bb4a-bf6c24d372c8",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dynamo produced 2 graphs with 1 graph break and 0 ops\n",
" Break reasons: \n",
"\n",
"TorchDynamo compilation metrics:\n",
"Function, Runtimes (s)\n",
"_compile, 0.0054, 0.0045, 0.0086, 0.0024, 0.0010, 0.0023, 0.0189, 0.0014, 0.0038, 0.0095\n",
"OutputGraph.call_user_compiler, 0.0000, 0.0000\n"
]
}
],
"source": [
"import torch_np\n",
"\n",
"def mean_sum_torch_np(X):\n",
" X_mean = torch_np.mean(X, 1)\n",
" X_sum = torch_np.sum(X_mean)\n",
" return X_mean\n",
"\n",
"explaination_np_input = dynamo.explain(mean_sum_torch_np, X_np)\n",
"print(explaination_np_input[-1])"
]
},
{
"cell_type": "markdown",
"id": "46110f76-236f-442a-8bea-0bdc7e23d124",
"metadata": {},
"source": [
"### PyTorch input"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "7308daca-5c74-4904-9b7e-1b530f6ea22d",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dynamo produced 2 graphs with 1 graph break and 0 ops\n",
" Break reasons: \n",
"\n",
"TorchDynamo compilation metrics:\n",
"Function, Runtimes (s)\n",
"_compile, 0.0062, 0.0056, 0.0122, 0.0103, 0.0102, 0.0030, 0.0010, 0.0023, 0.0187, 0.0014, 0.0034, 0.0054\n",
"OutputGraph.call_user_compiler, 0.0000, 0.0000\n"
]
}
],
"source": [
"explaination_torch_input = dynamo.explain(mean_sum_torch_np, X_torch)\n",
"print(explaination_torch_input[-1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "93fef875-b707-4c83-af52-ad91a8606f0d",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "2a6c6686-8c56-4395-8373-fc9191d52b94",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "6ce73184-fd76-48fa-9b1c-561665a35fc4",
"metadata": {},
"source": [
"## Try using a NumPy function that does not use `__array_function__`"
]
},
{
"cell_type": "markdown",
"id": "7581f42c-68d2-4c54-9d30-25a5f544df3e",
"metadata": {},
"source": [
"### Pure PyTorch version"
]
},
{
"cell_type": "code",
"execution_count": 83,
"id": "29b599b8-10e0-4d4c-a5a2-1b67efcc18c9",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import torch.linalg\n",
"\n",
"def det_sum_torch(X, Y):\n",
" X_det = torch.linalg.det(X)\n",
" Y_det = torch.linalg.det(Y)\n",
" return X_det + Y_det"
]
},
{
"cell_type": "code",
"execution_count": 85,
"id": "6b1a97b2-7e9a-4b1f-ab45-7e0012c0776f",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dynamo produced 1 graphs with 0 graph break and 3 ops\n",
" Break reasons: \n",
"\n",
"TorchDynamo compilation metrics:\n",
"Function, Runtimes (s)\n",
"_compile, 0.0089\n",
"OutputGraph.call_user_compiler, 0.0000\n"
]
}
],
"source": [
"explain_det_sum_pure_torch = dynamo.explain(det_sum_torch, X_torch, X_torch)\n",
"print(explain_det_sum_pure_torch[-1])"
]
},
{
"cell_type": "markdown",
"id": "b5e7795e-ef14-4683-adef-37d790ec923a",
"metadata": {},
"source": [
"### NumPy version"
]
},
{
"cell_type": "code",
"execution_count": 86,
"id": "614bc493-fd43-40e0-b8d6-cf068842e263",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import numpy.linalg\n",
"\n",
"def det_sum_np(X, Y):\n",
" X_det = numpy.linalg.det(X)\n",
" Y_det = numpy.linalg.det(Y)\n",
" return X_det + Y_det"
]
},
{
"cell_type": "markdown",
"id": "02d07c46-ba78-4540-804d-38fb80dbace6",
"metadata": {},
"source": [
"### Torch input"
]
},
{
"cell_type": "markdown",
"id": "8efc1996-ef38-4010-a6bf-7d8581216dbb",
"metadata": {},
"source": [
"With torch input, Dynamo produces no graphs at all (0 graphs, 0 ops):"
]
},
{
"cell_type": "code",
"execution_count": 88,
"id": "9e86423a-8a45-4019-bb49-7feeaaae7940",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dynamo produced 0 graphs with -1 graph break and 0 ops\n",
" Break reasons: \n",
"\n",
"TorchDynamo compilation metrics:\n",
"Function, Runtimes (s)\n",
"_compile, 0.0051, 0.0044\n"
]
}
],
"source": [
"explain_det_sum_torch = dynamo.explain(det_sum_np, X_torch, X_torch)\n",
"print(explain_det_sum_torch[-1])"
]
},
{
"cell_type": "markdown",
"id": "42542a86-eb26-4b8c-be18-ce5b0a960860",
"metadata": {},
"source": [
"### NumPy Input"
]
},
{
"cell_type": "markdown",
"id": "7c40ad85-1618-4a05-bb46-c11a299d6d73",
"metadata": {},
"source": [
"With NumPy input, Dynamo also produces no graphs at all (0 graphs, 0 ops):"
]
},
{
"cell_type": "code",
"execution_count": 89,
"id": "2bec5342-4364-4359-b305-b83eff335fac",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dynamo produced 0 graphs with -1 graph break and 0 ops\n",
" Break reasons: \n",
"\n",
"TorchDynamo compilation metrics:\n",
"Function, Runtimes (s)\n",
"_compile, 0.0031\n"
]
}
],
"source": [
"explain_det_sum_numpy = dynamo.explain(det_sum_np, X_np, X_np)\n",
"print(explain_det_sum_numpy[-1])"
]
},
{
"cell_type": "markdown",
"id": "f5f9651e-09e8-4d57-ae81-5d855dd972c5",
"metadata": {},
"source": [
"### Using `torch_np` directly"
]
},
{
"cell_type": "markdown",
"id": "0e784a88-299f-4c9c-be80-5dcc718c4755",
"metadata": {},
"source": [
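"Running `dynamo.explain` on this version with torch inputs raises an `InternalTorchDynamoError`, caused by a `NameError` (`name 'torch' is not defined`) while Dynamo builds its guards:"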
]
},
{
"cell_type": "code",
"execution_count": 90,
"id": "2b84be48-3a8e-4b7b-a260-7d2a18a1bb66",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from torch_np import linalg\n",
"\n",
"def det_sum_torch_np(X, Y):\n",
" X_det = linalg.det(X)\n",
" Y_det = linalg.det(Y)\n",
" return X_det + Y_det"
]
},
{
"cell_type": "code",
"execution_count": 91,
"id": "2d4c2735-0ea1-49c1-ac61-a0cb57e9e688",
"metadata": {
"tags": []
},
"outputs": [
{
"ename": "InternalTorchDynamoError",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"File \u001b[0;32m~/Desktop/pytorch1/torch/_dynamo/convert_frame.py:431\u001b[0m, in \u001b[0;36m_compile\u001b[0;34m(code, globals, locals, builtins, compiler_fn, one_graph, export, export_constraints, hooks, frame, frame_state)\u001b[0m\n\u001b[1;32m 430\u001b[0m CleanupManager\u001b[38;5;241m.\u001b[39minstance[out_code] \u001b[38;5;241m=\u001b[39m output\u001b[38;5;241m.\u001b[39mcleanups\n\u001b[0;32m--> 431\u001b[0m check_fn \u001b[38;5;241m=\u001b[39m \u001b[43mCheckFunctionManager\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 432\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 433\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mlocals\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 434\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mglobals\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 435\u001b[0m \u001b[43m \u001b[49m\u001b[43mhooks\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mguard_fail_fn\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mhooks\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 436\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 438\u001b[0m guarded_code \u001b[38;5;241m=\u001b[39m GuardedCode(out_code, check_fn\u001b[38;5;241m.\u001b[39mcheck_fn)\n",
"File \u001b[0;32m~/Desktop/pytorch1/torch/_dynamo/guards.py:688\u001b[0m, in \u001b[0;36mCheckFunctionManager.__init__\u001b[0;34m(self, output_graph, f_locals, f_globals, guard_fail_fn)\u001b[0m\n\u001b[1;32m 687\u001b[0m \u001b[38;5;28;01mcontinue\u001b[39;00m\n\u001b[0;32m--> 688\u001b[0m \u001b[43mguard\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcreate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlocal_builder\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mglobal_builder\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 689\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcheck_fn \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcompile_check_fn(\n\u001b[1;32m 690\u001b[0m local_builder, global_builder, guards, guard_fail_fn\n\u001b[1;32m 691\u001b[0m )\n",
"File \u001b[0;32m~/Desktop/pytorch1/torch/_guards.py:184\u001b[0m, in \u001b[0;36mGuard.create\u001b[0;34m(self, local_builder, global_builder)\u001b[0m\n\u001b[1;32m 183\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcreate\u001b[39m(\u001b[38;5;28mself\u001b[39m, local_builder: GuardBuilderBase, global_builder: GuardBuilderBase):\n\u001b[0;32m--> 184\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcreate_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msource\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mselect\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlocal_builder\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mglobal_builder\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Desktop/pytorch1/torch/_dynamo/guards.py:189\u001b[0m, in \u001b[0;36mGuardBuilder.TYPE_MATCH\u001b[0;34m(self, guard)\u001b[0m\n\u001b[1;32m 187\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mTYPE_MATCH\u001b[39m(\u001b[38;5;28mself\u001b[39m, guard: Guard):\n\u001b[1;32m 188\u001b[0m \u001b[38;5;66;03m# ___check_type_id is same as `id(type(x)) == y`\u001b[39;00m\n\u001b[0;32m--> 189\u001b[0m t \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[43mguard\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 190\u001b[0m obj_id \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mid_ref(t)\n",
"File \u001b[0;32m~/Desktop/pytorch1/torch/_dynamo/guards.py:165\u001b[0m, in \u001b[0;36mGuardBuilder.get\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 164\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget\u001b[39m(\u001b[38;5;28mself\u001b[39m, name: \u001b[38;5;28mstr\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[0;32m--> 165\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43meval\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscope\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mCLOSURE_VARS\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m<string>:1\u001b[0m\n",
"\u001b[0;31mNameError\u001b[0m: name 'torch' is not defined\n\n\nYou can suppress this exception and fall back to eager by setting:\n torch._dynamo.config.suppress_errors = True\n",
"\nThe above exception was the direct cause of the following exception:\n",
"\u001b[0;31mInternalTorchDynamoError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[91], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m explain_det_sum_torch_np \u001b[38;5;241m=\u001b[39m \u001b[43mdynamo\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexplain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdet_sum_torch_np\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX_torch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX_torch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28mprint\u001b[39m(explain_det_sum_torch_np[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m])\n",
"File \u001b[0;32m~/mambaforge/envs/pytorch1/lib/python3.10/unittest/mock.py:1379\u001b[0m, in \u001b[0;36m_patch.decorate_callable.<locals>.patched\u001b[0;34m(*args, **keywargs)\u001b[0m\n\u001b[1;32m 1374\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(func)\n\u001b[1;32m 1375\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mpatched\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkeywargs):\n\u001b[1;32m 1376\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdecoration_helper(patched,\n\u001b[1;32m 1377\u001b[0m args,\n\u001b[1;32m 1378\u001b[0m keywargs) \u001b[38;5;28;01mas\u001b[39;00m (newargs, newkeywargs):\n\u001b[0;32m-> 1379\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mnewargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mnewkeywargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Desktop/pytorch1/torch/_dynamo/eval_frame.py:564\u001b[0m, in \u001b[0;36mexplain\u001b[0;34m(f, *args, **kwargs)\u001b[0m\n\u001b[1;32m 558\u001b[0m opt_f \u001b[38;5;241m=\u001b[39m optimize(\n\u001b[1;32m 559\u001b[0m dynamo_graph_accumulating_compiler,\n\u001b[1;32m 560\u001b[0m nopython\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m 561\u001b[0m guard_export_fn\u001b[38;5;241m=\u001b[39mguard_export_print,\n\u001b[1;32m 562\u001b[0m )(f)\n\u001b[1;32m 563\u001b[0m \u001b[38;5;66;03m# TODO(voz): We may have instances of `f` that mutate inputs, we should track sideffects and reject.\u001b[39;00m\n\u001b[0;32m--> 564\u001b[0m \u001b[43mopt_f\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 566\u001b[0m graph_count \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlen\u001b[39m(graphs)\n\u001b[1;32m 568\u001b[0m \u001b[38;5;66;03m# For the explanation summary, dedupe reasons by the innermost stack frame and dedupe by it.\u001b[39;00m\n",
"File \u001b[0;32m~/Desktop/pytorch1/torch/_dynamo/eval_frame.py:253\u001b[0m, in \u001b[0;36m_TorchDynamoContext.__call__.<locals>._fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 251\u001b[0m dynamic_ctx\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__enter__\u001b[39m()\n\u001b[1;32m 252\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 253\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 254\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 255\u001b[0m set_eval_frame(prior)\n",
"Cell \u001b[0;32mIn[90], line 4\u001b[0m, in \u001b[0;36mdet_sum_torch_np\u001b[0;34m(X, Y)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdet_sum_torch_np\u001b[39m(X, Y):\n\u001b[0;32m----> 4\u001b[0m X_det \u001b[38;5;241m=\u001b[39m linalg\u001b[38;5;241m.\u001b[39mdet(X)\n\u001b[1;32m 5\u001b[0m Y_det \u001b[38;5;241m=\u001b[39m linalg\u001b[38;5;241m.\u001b[39mdet(Y)\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m X_det \u001b[38;5;241m+\u001b[39m Y_det\n",
"Cell \u001b[0;32mIn[90], line 5\u001b[0m, in \u001b[0;36m<resume in det_sum_torch_np>\u001b[0;34m(___stack0, Y)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdet_sum_torch_np\u001b[39m(X, Y):\n\u001b[1;32m 4\u001b[0m X_det \u001b[38;5;241m=\u001b[39m linalg\u001b[38;5;241m.\u001b[39mdet(X)\n\u001b[0;32m----> 5\u001b[0m Y_det \u001b[38;5;241m=\u001b[39m linalg\u001b[38;5;241m.\u001b[39mdet(Y)\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m X_det \u001b[38;5;241m+\u001b[39m Y_det\n",
"Cell \u001b[0;32mIn[90], line 6\u001b[0m, in \u001b[0;36m<resume in det_sum_torch_np>\u001b[0;34m(___stack0, X_det)\u001b[0m\n\u001b[1;32m 4\u001b[0m X_det \u001b[38;5;241m=\u001b[39m linalg\u001b[38;5;241m.\u001b[39mdet(X)\n\u001b[1;32m 5\u001b[0m Y_det \u001b[38;5;241m=\u001b[39m linalg\u001b[38;5;241m.\u001b[39mdet(Y)\n\u001b[0;32m----> 6\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m X_det \u001b[38;5;241m+\u001b[39m Y_det\n",
"File \u001b[0;32m~/mambaforge/envs/pytorch1/lib/python3.10/site-packages/torch_np/_normalizations.py:198\u001b[0m, in \u001b[0;36mnormalizer.<locals>.normalizer_inner.<locals>.wrapped\u001b[0;34m(*args, **kwds)\u001b[0m\n\u001b[1;32m 186\u001b[0m args \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 187\u001b[0m \u001b[38;5;28mtuple\u001b[39m(\n\u001b[1;32m 188\u001b[0m maybe_normalize(arg, parm)\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 191\u001b[0m \u001b[38;5;241m+\u001b[39m args[\u001b[38;5;28mlen\u001b[39m(params\u001b[38;5;241m.\u001b[39mvalues()) :]\n\u001b[1;32m 192\u001b[0m )\n\u001b[1;32m 194\u001b[0m kwds \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 195\u001b[0m name: maybe_normalize(arg, params[name]) \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m params \u001b[38;5;28;01melse\u001b[39;00m arg\n\u001b[1;32m 196\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m name, arg \u001b[38;5;129;01min\u001b[39;00m kwds\u001b[38;5;241m.\u001b[39mitems()\n\u001b[1;32m 197\u001b[0m }\n\u001b[0;32m--> 198\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwds\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mout\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m params:\n\u001b[1;32m 201\u001b[0m out \u001b[38;5;241m=\u001b[39m sig\u001b[38;5;241m.\u001b[39mbind(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwds)\u001b[38;5;241m.\u001b[39marguments\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mout\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"File \u001b[0;32m~/Desktop/pytorch1/torch/_dynamo/eval_frame.py:402\u001b[0m, in \u001b[0;36mcatch_errors_wrapper.<locals>.catch_errors\u001b[0;34m(frame, cache_size, frame_state)\u001b[0m\n\u001b[1;32m 399\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m hijacked_callback(frame, cache_size, hooks, frame_state)\n\u001b[1;32m 401\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m compile_lock:\n\u001b[0;32m--> 402\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcallback\u001b[49m\u001b[43m(\u001b[49m\u001b[43mframe\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcache_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhooks\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mframe_state\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Desktop/pytorch1/torch/_dynamo/convert_frame.py:480\u001b[0m, in \u001b[0;36mconvert_frame.<locals>._convert_frame\u001b[0;34m(frame, cache_size, hooks, frame_state)\u001b[0m\n\u001b[1;32m 478\u001b[0m counters[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mframes\u001b[39m\u001b[38;5;124m\"\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtotal\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 479\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 480\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43minner_convert\u001b[49m\u001b[43m(\u001b[49m\u001b[43mframe\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcache_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhooks\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mframe_state\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 481\u001b[0m counters[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mframes\u001b[39m\u001b[38;5;124m\"\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mok\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 482\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m result\n",
"File \u001b[0;32m~/Desktop/pytorch1/torch/_dynamo/convert_frame.py:117\u001b[0m, in \u001b[0;36mwrap_convert_context.<locals>._fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 115\u001b[0m cleanup \u001b[38;5;241m=\u001b[39m setup_compile_debug()\n\u001b[1;32m 116\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 117\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 118\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 119\u001b[0m cleanup\u001b[38;5;241m.\u001b[39mclose()\n",
"File \u001b[0;32m~/Desktop/pytorch1/torch/_dynamo/convert_frame.py:318\u001b[0m, in \u001b[0;36mconvert_frame_assert.<locals>._convert_frame_assert\u001b[0;34m(frame, cache_size, hooks, frame_state)\u001b[0m\n\u001b[1;32m 303\u001b[0m initial_deterministic_algorithms_state \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 304\u001b[0m torch\u001b[38;5;241m.\u001b[39mare_deterministic_algorithms_enabled()\n\u001b[1;32m 305\u001b[0m )\n\u001b[1;32m 307\u001b[0m signpost_event(\n\u001b[1;32m 308\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdynamo\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 309\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_convert_frame_assert._compile\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 315\u001b[0m },\n\u001b[1;32m 316\u001b[0m )\n\u001b[0;32m--> 318\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_compile\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 319\u001b[0m \u001b[43m \u001b[49m\u001b[43mframe\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mf_code\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 320\u001b[0m \u001b[43m \u001b[49m\u001b[43mframe\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mf_globals\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 321\u001b[0m \u001b[43m \u001b[49m\u001b[43mframe\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mf_locals\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 322\u001b[0m \u001b[43m \u001b[49m\u001b[43mframe\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mf_builtins\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 323\u001b[0m \u001b[43m \u001b[49m\u001b[43mcompiler_fn\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 324\u001b[0m \u001b[43m \u001b[49m\u001b[43mone_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 325\u001b[0m \u001b[43m \u001b[49m\u001b[43mexport\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 326\u001b[0m \u001b[43m \u001b[49m\u001b[43mexport_constraints\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 327\u001b[0m 
\u001b[43m \u001b[49m\u001b[43mhooks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 328\u001b[0m \u001b[43m \u001b[49m\u001b[43mframe\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 329\u001b[0m \u001b[43m \u001b[49m\u001b[43mframe_state\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mframe_state\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 330\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Desktop/pytorch1/torch/_dynamo/utils.py:177\u001b[0m, in \u001b[0;36mdynamo_timed.<locals>.dynamo_timed_inner.<locals>.time_wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 175\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mrecord_function(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mkey\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m (dynamo_timed)\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 176\u001b[0m t0 \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[0;32m--> 177\u001b[0m r \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 178\u001b[0m time_spent \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime() \u001b[38;5;241m-\u001b[39m t0\n\u001b[1;32m 179\u001b[0m \u001b[38;5;66;03m# print(f\"Dynamo timer: key={key}, latency={latency:.2f} sec\")\u001b[39;00m\n",
"File \u001b[0;32m~/Desktop/pytorch1/torch/_dynamo/convert_frame.py:468\u001b[0m, in \u001b[0;36m_compile\u001b[0;34m(code, globals, locals, builtins, compiler_fn, one_graph, export, export_constraints, hooks, frame, frame_state)\u001b[0m\n\u001b[1;32m 466\u001b[0m exception_handler(e, code, frame)\n\u001b[1;32m 467\u001b[0m \u001b[38;5;66;03m# TODO: Why??? Why not raise the original exception as is\u001b[39;00m\n\u001b[0;32m--> 468\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m InternalTorchDynamoError() \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01me\u001b[39;00m\n",
"\u001b[0;31mInternalTorchDynamoError\u001b[0m: "
]
}
],
"source": [
"explain_det_sum_torch_np = dynamo.explain(det_sum_torch_np, X_torch, X_torch)\n",
"print(explain_det_sum_torch_np[-1])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "pytorch1 (python3)",
"language": "python",
"name": "conda-env-pytorch1-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment