Dynamo multi_tensor_sgd
$ PT_XLA_DEBUG=1 python test_dynamo.py
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1715804980.303798 3806764 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1715804980.337946 3806763 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1715804980.343785 3806766 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1715804980.351582 3806765 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1715804980.354106 3806453 service.cc:145] XLA service 0x55e17dd4f190 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1715804980.354143 3806453 service.cc:153] StreamExecutor device (0): Tesla T4, Compute Capability 7.5
I0000 00:00:1715804980.354150 3806453 service.cc:153] StreamExecutor device (1): Tesla T4, Compute Capability 7.5
I0000 00:00:1715804980.354157 3806453 service.cc:153] StreamExecutor device (2): Tesla T4, Compute Capability 7.5
I0000 00:00:1715804980.354170 3806453 service.cc:153] StreamExecutor device (3): Tesla T4, Compute Capability 7.5
I0000 00:00:1715804980.365586 3806453 se_gpu_pjrt_client.cc:853] Using BFC allocator.
I0000 00:00:1715804980.365681 3806453 gpu_helpers.cc:107] XLA backend allocating 11731746816 bytes on device 0 for BFCAllocator.
I0000 00:00:1715804980.365724 3806453 gpu_helpers.cc:107] XLA backend allocating 11731746816 bytes on device 1 for BFCAllocator.
I0000 00:00:1715804980.365759 3806453 gpu_helpers.cc:107] XLA backend allocating 11731746816 bytes on device 2 for BFCAllocator.
I0000 00:00:1715804980.365805 3806453 gpu_helpers.cc:107] XLA backend allocating 11731746816 bytes on device 3 for BFCAllocator.
I0000 00:00:1715804980.365836 3806453 gpu_helpers.cc:147] XLA backend will use up to 3910582272 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1715804980.365862 3806453 gpu_helpers.cc:147] XLA backend will use up to 3910582272 bytes on device 1 for CollectiveBFCAllocator.
I0000 00:00:1715804980.365887 3806453 gpu_helpers.cc:147] XLA backend will use up to 3910582272 bytes on device 2 for CollectiveBFCAllocator.
I0000 00:00:1715804980.365911 3806453 gpu_helpers.cc:147] XLA backend will use up to 3910582272 bytes on device 3 for CollectiveBFCAllocator.
I0000 00:00:1715804980.366071 3806453 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1715804980.368051 3806453 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1715804980.370038 3806453 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1715804980.371979 3806453 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
-------- 0 -----------------
Compilation Analysis: ================================================================================
Compilation Analysis: Compilation Cause
Compilation Analysis: mark_step when dynamo processing input graphs
Compilation Analysis: Graph Info:
Compilation Analysis: Graph Hash: 455b827bd004a389d858839fb4ef3944
Compilation Analysis: Number of Graph Inputs: 8
Compilation Analysis: Number of Graph Outputs: 8
Compilation Analysis: Python Frame Triggered Execution:
Compilation Analysis: mark_step (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/xla_model.py:1056)
Compilation Analysis: extract_compiled_graph (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:542)
Compilation Analysis: fwd (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py:51)
Compilation Analysis: g (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:89)
Compilation Analysis: call_func_at_runtime_with_args (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:113)
Compilation Analysis: forward (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:505)
Compilation Analysis: apply (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/autograd/function.py:598)
Compilation Analysis: g (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:89)
Compilation Analysis: ..........
Compilation Analysis: --------------------------------------------------------------------------------
Compilation Analysis: ================================================================================
Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis: mark_step when dynamo processing input graphs
Execution Analysis: Graph Info:
Execution Analysis: Graph Hash: 455b827bd004a389d858839fb4ef3944
Execution Analysis: Number of Graph Inputs: 8
Execution Analysis: Number of Graph Outputs: 8
Execution Analysis: Python Frame Triggered Execution:
Execution Analysis: mark_step (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/xla_model.py:1056)
Execution Analysis: extract_compiled_graph (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:542)
Execution Analysis: fwd (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py:51)
Execution Analysis: g (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:89)
Execution Analysis: call_func_at_runtime_with_args (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:113)
Execution Analysis: forward (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:505)
Execution Analysis: apply (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/autograd/function.py:598)
Execution Analysis: g (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:89)
Execution Analysis: ..........
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
Compilation Analysis: ================================================================================
Compilation Analysis: Compilation Cause
Compilation Analysis: dynamo is compiling a FX graph to HLO
Compilation Analysis: Graph Info:
Compilation Analysis: Graph Hash: 74da15968ea4ebca0372e98ddca11641
Compilation Analysis: Number of Graph Inputs: 8
Compilation Analysis: Number of Graph Outputs: 8
Compilation Analysis: Python Frame Triggered Execution:
Compilation Analysis: extract_graph_helper (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:339)
Compilation Analysis: extract_internal (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:374)
Compilation Analysis: extract_compiled_graph (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:621)
Compilation Analysis: fwd (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py:51)
Compilation Analysis: g (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:89)
Compilation Analysis: call_func_at_runtime_with_args (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:113)
Compilation Analysis: forward (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:505)
Compilation Analysis: apply (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/autograd/function.py:598)
Compilation Analysis: ..........
Compilation Analysis: --------------------------------------------------------------------------------
Compilation Analysis: ================================================================================
Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis: dynamo is executing a compiled program
Execution Analysis: Graph Info:
Execution Analysis: Graph Hash: 74da15968ea4ebca0372e98ddca11641
Execution Analysis: Number of Graph Inputs: 8
Execution Analysis: Number of Graph Outputs: 8
Execution Analysis: Python Frame Triggered Execution:
Execution Analysis: optimized_mod (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:427)
Execution Analysis: forward (<eval_with_key>.1:5)
Execution Analysis: _call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1541)
Execution Analysis: _wrapped_call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1532)
Execution Analysis: __call__ (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:304)
Execution Analysis: call_wrapped (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:737)
Execution Analysis: _lazy_forward (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/_lazy_graph_module.py:123)
Execution Analysis: _call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1541)
Execution Analysis: ..........
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
Compilation Analysis: ================================================================================
Compilation Analysis: Compilation Cause
Compilation Analysis: mark_step when dynamo processing input graphs
Compilation Analysis: Graph Info:
Compilation Analysis: Graph Hash: 8e6aa44c425d7d283e8e959f91190d1a
Compilation Analysis: Number of Graph Inputs: 0
Compilation Analysis: Number of Graph Outputs: 1
Compilation Analysis: Python Frame Triggered Execution:
Compilation Analysis: mark_step (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/xla_model.py:1056)
Compilation Analysis: extract_compiled_graph (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:542)
Compilation Analysis: fwd (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py:51)
Compilation Analysis: g (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:89)
Compilation Analysis: inner (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/external_utils.py:36)
Compilation Analysis: _fn (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:451)
Compilation Analysis: call_func_at_runtime_with_args (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:113)
Compilation Analysis: call_compiled_backward (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:831)
Compilation Analysis: ..........
Compilation Analysis: --------------------------------------------------------------------------------
Compilation Analysis: ================================================================================
Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis: mark_step when dynamo processing input graphs
Execution Analysis: Graph Info:
Execution Analysis: Graph Hash: 8e6aa44c425d7d283e8e959f91190d1a
Execution Analysis: Number of Graph Inputs: 0
Execution Analysis: Number of Graph Outputs: 1
Execution Analysis: Python Frame Triggered Execution:
Execution Analysis: mark_step (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/xla_model.py:1056)
Execution Analysis: extract_compiled_graph (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:542)
Execution Analysis: fwd (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py:51)
Execution Analysis: g (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:89)
Execution Analysis: inner (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/external_utils.py:36)
Execution Analysis: _fn (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:451)
Execution Analysis: call_func_at_runtime_with_args (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:113)
Execution Analysis: call_compiled_backward (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:831)
Execution Analysis: ..........
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
Compilation Analysis: ================================================================================
Compilation Analysis: Compilation Cause
Compilation Analysis: dynamo is compiling a FX graph to HLO
Compilation Analysis: Graph Info:
Compilation Analysis: Graph Hash: a9a891709155550a12b9907720b7d210
Compilation Analysis: Number of Graph Inputs: 8
Compilation Analysis: Number of Graph Outputs: 6
Compilation Analysis: Python Frame Triggered Execution:
Compilation Analysis: extract_graph_helper (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:339)
Compilation Analysis: extract_internal (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:374)
Compilation Analysis: extract_compiled_graph (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:621)
Compilation Analysis: fwd (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py:51)
Compilation Analysis: g (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:89)
Compilation Analysis: inner (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/external_utils.py:36)
Compilation Analysis: _fn (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:451)
Compilation Analysis: call_func_at_runtime_with_args (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:113)
Compilation Analysis: ..........
Compilation Analysis: --------------------------------------------------------------------------------
Compilation Analysis: ================================================================================
Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis: dynamo is executing a compiled program
Execution Analysis: Graph Info:
Execution Analysis: Graph Hash: a9a891709155550a12b9907720b7d210
Execution Analysis: Number of Graph Inputs: 8
Execution Analysis: Number of Graph Outputs: 6
Execution Analysis: Python Frame Triggered Execution:
Execution Analysis: optimized_mod (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:427)
Execution Analysis: forward (<eval_with_key>.3:5)
Execution Analysis: _call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1541)
Execution Analysis: _wrapped_call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1532)
Execution Analysis: __call__ (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:304)
Execution Analysis: call_wrapped (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:737)
Execution Analysis: _lazy_forward (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/_lazy_graph_module.py:123)
Execution Analysis: _call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1541)
Execution Analysis: ..........
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
-------- 1 -----------------
Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis: dynamo is executing a compiled program
Execution Analysis: Graph Info:
Execution Analysis: Graph Hash: 74da15968ea4ebca0372e98ddca11641
Execution Analysis: Number of Graph Inputs: 8
Execution Analysis: Number of Graph Outputs: 8
Execution Analysis: Python Frame Triggered Execution:
Execution Analysis: optimized_mod (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:427)
Execution Analysis: forward (<eval_with_key>.1:5)
Execution Analysis: _call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1541)
Execution Analysis: _wrapped_call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1532)
Execution Analysis: __call__ (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:304)
Execution Analysis: call_wrapped (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:737)
Execution Analysis: fwd (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py:53)
Execution Analysis: g (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:89)
Execution Analysis: ..........
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis: most likely user code trying to access tensor value before mark_step
Execution Analysis: Graph Info:
Execution Analysis: Graph Hash: 8e6aa44c425d7d283e8e959f91190d1a
Execution Analysis: Number of Graph Inputs: 0
Execution Analysis: Number of Graph Outputs: 1
Execution Analysis: Python Frame Triggered Execution:
Execution Analysis: optimized_mod (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:398)
Execution Analysis: forward (<eval_with_key>.3:5)
Execution Analysis: _call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1541)
Execution Analysis: _wrapped_call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1532)
Execution Analysis: __call__ (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:304)
Execution Analysis: call_wrapped (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:737)
Execution Analysis: fwd (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py:53)
Execution Analysis: g (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:89)
Execution Analysis: ..........
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis: dynamo is executing a compiled program
Execution Analysis: Graph Info:
Execution Analysis: Graph Hash: a9a891709155550a12b9907720b7d210
Execution Analysis: Number of Graph Inputs: 8
Execution Analysis: Number of Graph Outputs: 6
Execution Analysis: Python Frame Triggered Execution:
Execution Analysis: optimized_mod (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:427)
Execution Analysis: forward (<eval_with_key>.3:5)
Execution Analysis: _call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1541)
Execution Analysis: _wrapped_call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1532)
Execution Analysis: __call__ (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:304)
Execution Analysis: call_wrapped (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:737)
Execution Analysis: fwd (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py:53)
Execution Analysis: g (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:89)
Execution Analysis: ..........
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
-------- 2 -----------------
Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis: dynamo is executing a compiled program
Execution Analysis: Graph Info:
Execution Analysis: Graph Hash: 74da15968ea4ebca0372e98ddca11641
Execution Analysis: Number of Graph Inputs: 8
Execution Analysis: Number of Graph Outputs: 8
Execution Analysis: Python Frame Triggered Execution:
Execution Analysis: optimized_mod (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:427)
Execution Analysis: forward (<eval_with_key>.1:5)
Execution Analysis: _call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1541)
Execution Analysis: _wrapped_call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1532)
Execution Analysis: __call__ (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:304)
Execution Analysis: call_wrapped (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:737)
Execution Analysis: fwd (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py:53)
Execution Analysis: g (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:89)
Execution Analysis: ..........
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis: most likely user code trying to access tensor value before mark_step
Execution Analysis: Graph Info:
Execution Analysis: Graph Hash: 8e6aa44c425d7d283e8e959f91190d1a
Execution Analysis: Number of Graph Inputs: 0
Execution Analysis: Number of Graph Outputs: 1
Execution Analysis: Python Frame Triggered Execution:
Execution Analysis: optimized_mod (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:398)
Execution Analysis: forward (<eval_with_key>.3:5)
Execution Analysis: _call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1541)
Execution Analysis: _wrapped_call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1532)
Execution Analysis: __call__ (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:304)
Execution Analysis: call_wrapped (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:737)
Execution Analysis: fwd (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py:53)
Execution Analysis: g (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:89)
Execution Analysis: ..........
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis: dynamo is executing a compiled program
Execution Analysis: Graph Info:
Execution Analysis: Graph Hash: a9a891709155550a12b9907720b7d210
Execution Analysis: Number of Graph Inputs: 8
Execution Analysis: Number of Graph Outputs: 6
Execution Analysis: Python Frame Triggered Execution:
Execution Analysis: optimized_mod (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:427)
Execution Analysis: forward (<eval_with_key>.3:5)
Execution Analysis: _call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1541)
Execution Analysis: _wrapped_call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1532)
Execution Analysis: __call__ (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:304)
Execution Analysis: call_wrapped (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:737)
Execution Analysis: fwd (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py:53)
Execution Analysis: g (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:89)
Execution Analysis: ..........
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
Done!
WARNING:pt-xla-profiler:================================================================================
WARNING:pt-xla-profiler:Unlowered Op usage summary (more of these ops, lower performance)
WARNING:pt-xla-profiler:Note: _local_scalar_dense typically indicates CPU context access
WARNING:pt-xla-profiler:--------------------------------------------------------------------------------
WARNING:pt-xla-profiler:FRAME (count=3):
WARNING:pt-xla-profiler:Unlowered Op: "xla_cpu_fallback"
WARNING:pt-xla-profiler:
WARNING:pt-xla-profiler:
WARNING:pt-xla-profiler:FRAME (count=3):
WARNING:pt-xla-profiler: _multi_tensor_sgd (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/optim/sgd.py:360)
WARNING:pt-xla-profiler: sgd (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/optim/sgd.py:245)
WARNING:pt-xla-profiler: step (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/optim/sgd.py:80)
WARNING:pt-xla-profiler: _use_grad (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/optim/optimizer.py:76)
WARNING:pt-xla-profiler: wrapper (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/optim/optimizer.py:391)
WARNING:pt-xla-profiler: _fn (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:451)
WARNING:pt-xla-profiler: torch_dynamo_resume_in_train_loop_at_43 (/home/bbahl/test_dynamo.py:44)
WARNING:pt-xla-profiler: torch_dynamo_resume_in_train_loop_at_39 (/home/bbahl/test_dynamo.py:43)
WARNING:pt-xla-profiler: train_loop (/home/bbahl/test_dynamo.py:39)
WARNING:pt-xla-profiler: _fn (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:451)
WARNING:pt-xla-profiler: <module> (/home/bbahl/test_dynamo.py:55)
WARNING:pt-xla-profiler:
WARNING:pt-xla-profiler:
WARNING:pt-xla-profiler:================================================================================
# test_dynamo.py: minimal FashionMNIST repro of the multi_tensor_sgd
# CPU fallback under torch.compile with the openxla backend.
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import torch_xla.core.xla_model as xm

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)
train_dataloader = DataLoader(training_data, batch_size=1024)

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

device = xm.xla_device()
model = NeuralNetwork().to(device)

def train_loop(X, y, model, optimizer):
    optimizer.zero_grad()
    loss_fn = nn.CrossEntropyLoss()
    pred = model(X)
    loss = loss_fn(pred, y)
    loss.backward()
    optimizer.step()

# foreach=True selects the multi-tensor (_multi_tensor_sgd) update path;
# optimizer.step is excluded from dynamo tracing and runs eagerly.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, foreach=True)
optimizer.step = torch._dynamo.disable(optimizer.step)
train_loop_ = torch.compile(train_loop, backend='openxla')

for batch, (X, y) in enumerate(train_dataloader):
    print(f"-------- {batch} -----------------")
    X = X.to(device)
    y = y.to(device)
    train_loop_(X, y, model, optimizer)
    if batch == 2:
        break
print("Done!")
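The profiler summary above attributes the repeated `xla_cpu_fallback` frames to `_multi_tensor_sgd`, i.e. the update path selected by `foreach=True` in the script. A minimal sketch of one way to sidestep that specific path, assuming standard `torch.optim.SGD` semantics and shown on CPU for brevity (not a verified fix for the XLA trace above): construct the optimizer with `foreach=False` so the single-tensor update loop is used instead.

```python
import torch
from torch import nn

# Tiny stand-in model; the real script uses the FashionMNIST network.
model = nn.Linear(4, 2)

# foreach=False forces the single-tensor SGD update loop, avoiding the
# _multi_tensor_sgd path the profiler flagged.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, foreach=False)

x = torch.randn(8, 4)
y = torch.randint(0, 2, (8,))
loss_fn = nn.CrossEntropyLoss()

# One ordinary training step.
optimizer.zero_grad()
loss = loss_fn(model(x), y)
loss.backward()
optimizer.step()
```

Whether this removes the warning in the XLA run would need to be re-checked under `PT_XLA_DEBUG=1`; the sketch only shows where the `foreach` switch lives.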