bhavya01/output · Created May 15, 2024 20:37
Dynamo multi_tensor_sgd
$ PT_XLA_DEBUG=1 python test_dynamo.py
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1715804980.303798 3806764 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1715804980.337946 3806763 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1715804980.343785 3806766 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1715804980.351582 3806765 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1715804980.354106 3806453 service.cc:145] XLA service 0x55e17dd4f190 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1715804980.354143 3806453 service.cc:153] StreamExecutor device (0): Tesla T4, Compute Capability 7.5
I0000 00:00:1715804980.354150 3806453 service.cc:153] StreamExecutor device (1): Tesla T4, Compute Capability 7.5
I0000 00:00:1715804980.354157 3806453 service.cc:153] StreamExecutor device (2): Tesla T4, Compute Capability 7.5
I0000 00:00:1715804980.354170 3806453 service.cc:153] StreamExecutor device (3): Tesla T4, Compute Capability 7.5
I0000 00:00:1715804980.365586 3806453 se_gpu_pjrt_client.cc:853] Using BFC allocator.
I0000 00:00:1715804980.365681 3806453 gpu_helpers.cc:107] XLA backend allocating 11731746816 bytes on device 0 for BFCAllocator.
I0000 00:00:1715804980.365724 3806453 gpu_helpers.cc:107] XLA backend allocating 11731746816 bytes on device 1 for BFCAllocator.
I0000 00:00:1715804980.365759 3806453 gpu_helpers.cc:107] XLA backend allocating 11731746816 bytes on device 2 for BFCAllocator.
I0000 00:00:1715804980.365805 3806453 gpu_helpers.cc:107] XLA backend allocating 11731746816 bytes on device 3 for BFCAllocator.
I0000 00:00:1715804980.365836 3806453 gpu_helpers.cc:147] XLA backend will use up to 3910582272 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1715804980.365862 3806453 gpu_helpers.cc:147] XLA backend will use up to 3910582272 bytes on device 1 for CollectiveBFCAllocator.
I0000 00:00:1715804980.365887 3806453 gpu_helpers.cc:147] XLA backend will use up to 3910582272 bytes on device 2 for CollectiveBFCAllocator.
I0000 00:00:1715804980.365911 3806453 gpu_helpers.cc:147] XLA backend will use up to 3910582272 bytes on device 3 for CollectiveBFCAllocator.
I0000 00:00:1715804980.366071 3806453 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1715804980.368051 3806453 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1715804980.370038 3806453 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1715804980.371979 3806453 cuda_executor.cc:1032] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
-------- 0 -----------------
Compilation Analysis: ================================================================================
Compilation Analysis: Compilation Cause
Compilation Analysis: mark_step when dynamo processing input graphs
Compilation Analysis: Graph Info:
Compilation Analysis: Graph Hash: 455b827bd004a389d858839fb4ef3944
Compilation Analysis: Number of Graph Inputs: 8
Compilation Analysis: Number of Graph Outputs: 8
Compilation Analysis: Python Frame Triggered Execution:
Compilation Analysis: mark_step (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/xla_model.py:1056)
Compilation Analysis: extract_compiled_graph (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:542)
Compilation Analysis: fwd (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py:51)
Compilation Analysis: g (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:89)
Compilation Analysis: call_func_at_runtime_with_args (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:113)
Compilation Analysis: forward (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:505)
Compilation Analysis: apply (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/autograd/function.py:598)
Compilation Analysis: g (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:89)
Compilation Analysis: ..........
Compilation Analysis: --------------------------------------------------------------------------------
Compilation Analysis: ================================================================================
Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis: mark_step when dynamo processing input graphs
Execution Analysis: Graph Info:
Execution Analysis: Graph Hash: 455b827bd004a389d858839fb4ef3944
Execution Analysis: Number of Graph Inputs: 8
Execution Analysis: Number of Graph Outputs: 8
Execution Analysis: Python Frame Triggered Execution:
Execution Analysis: mark_step (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/xla_model.py:1056)
Execution Analysis: extract_compiled_graph (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:542)
Execution Analysis: fwd (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py:51)
Execution Analysis: g (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:89)
Execution Analysis: call_func_at_runtime_with_args (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:113)
Execution Analysis: forward (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:505)
Execution Analysis: apply (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/autograd/function.py:598)
Execution Analysis: g (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:89)
Execution Analysis: ..........
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
Compilation Analysis: ================================================================================
Compilation Analysis: Compilation Cause
Compilation Analysis: dynamo is compiling a FX graph to HLO
Compilation Analysis: Graph Info:
Compilation Analysis: Graph Hash: 74da15968ea4ebca0372e98ddca11641
Compilation Analysis: Number of Graph Inputs: 8
Compilation Analysis: Number of Graph Outputs: 8
Compilation Analysis: Python Frame Triggered Execution:
Compilation Analysis: extract_graph_helper (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:339)
Compilation Analysis: extract_internal (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:374)
Compilation Analysis: extract_compiled_graph (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:621)
Compilation Analysis: fwd (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py:51)
Compilation Analysis: g (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:89)
Compilation Analysis: call_func_at_runtime_with_args (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:113)
Compilation Analysis: forward (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:505)
Compilation Analysis: apply (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/autograd/function.py:598)
Compilation Analysis: ..........
Compilation Analysis: --------------------------------------------------------------------------------
Compilation Analysis: ================================================================================
Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis: dynamo is executing a compiled program
Execution Analysis: Graph Info:
Execution Analysis: Graph Hash: 74da15968ea4ebca0372e98ddca11641
Execution Analysis: Number of Graph Inputs: 8
Execution Analysis: Number of Graph Outputs: 8
Execution Analysis: Python Frame Triggered Execution:
Execution Analysis: optimized_mod (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:427)
Execution Analysis: forward (<eval_with_key>.1:5)
Execution Analysis: _call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1541)
Execution Analysis: _wrapped_call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1532)
Execution Analysis: __call__ (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:304)
Execution Analysis: call_wrapped (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:737)
Execution Analysis: _lazy_forward (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/_lazy_graph_module.py:123)
Execution Analysis: _call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1541)
Execution Analysis: ..........
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
Compilation Analysis: ================================================================================
Compilation Analysis: Compilation Cause
Compilation Analysis: mark_step when dynamo processing input graphs
Compilation Analysis: Graph Info:
Compilation Analysis: Graph Hash: 8e6aa44c425d7d283e8e959f91190d1a
Compilation Analysis: Number of Graph Inputs: 0
Compilation Analysis: Number of Graph Outputs: 1
Compilation Analysis: Python Frame Triggered Execution:
Compilation Analysis: mark_step (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/xla_model.py:1056)
Compilation Analysis: extract_compiled_graph (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:542)
Compilation Analysis: fwd (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py:51)
Compilation Analysis: g (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:89)
Compilation Analysis: inner (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/external_utils.py:36)
Compilation Analysis: _fn (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:451)
Compilation Analysis: call_func_at_runtime_with_args (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:113)
Compilation Analysis: call_compiled_backward (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:831)
Compilation Analysis: ..........
Compilation Analysis: --------------------------------------------------------------------------------
Compilation Analysis: ================================================================================
Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis: mark_step when dynamo processing input graphs
Execution Analysis: Graph Info:
Execution Analysis: Graph Hash: 8e6aa44c425d7d283e8e959f91190d1a
Execution Analysis: Number of Graph Inputs: 0
Execution Analysis: Number of Graph Outputs: 1
Execution Analysis: Python Frame Triggered Execution:
Execution Analysis: mark_step (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/xla_model.py:1056)
Execution Analysis: extract_compiled_graph (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:542)
Execution Analysis: fwd (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py:51)
Execution Analysis: g (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:89)
Execution Analysis: inner (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/external_utils.py:36)
Execution Analysis: _fn (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:451)
Execution Analysis: call_func_at_runtime_with_args (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:113)
Execution Analysis: call_compiled_backward (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:831)
Execution Analysis: ..........
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
Compilation Analysis: ================================================================================
Compilation Analysis: Compilation Cause
Compilation Analysis: dynamo is compiling a FX graph to HLO
Compilation Analysis: Graph Info:
Compilation Analysis: Graph Hash: a9a891709155550a12b9907720b7d210
Compilation Analysis: Number of Graph Inputs: 8
Compilation Analysis: Number of Graph Outputs: 6
Compilation Analysis: Python Frame Triggered Execution:
Compilation Analysis: extract_graph_helper (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:339)
Compilation Analysis: extract_internal (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:374)
Compilation Analysis: extract_compiled_graph (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:621)
Compilation Analysis: fwd (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py:51)
Compilation Analysis: g (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:89)
Compilation Analysis: inner (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/external_utils.py:36)
Compilation Analysis: _fn (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:451)
Compilation Analysis: call_func_at_runtime_with_args (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:113)
Compilation Analysis: ..........
Compilation Analysis: --------------------------------------------------------------------------------
Compilation Analysis: ================================================================================
Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis: dynamo is executing a compiled program
Execution Analysis: Graph Info:
Execution Analysis: Graph Hash: a9a891709155550a12b9907720b7d210
Execution Analysis: Number of Graph Inputs: 8
Execution Analysis: Number of Graph Outputs: 6
Execution Analysis: Python Frame Triggered Execution:
Execution Analysis: optimized_mod (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:427)
Execution Analysis: forward (<eval_with_key>.3:5)
Execution Analysis: _call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1541)
Execution Analysis: _wrapped_call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1532)
Execution Analysis: __call__ (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:304)
Execution Analysis: call_wrapped (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:737)
Execution Analysis: _lazy_forward (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/_lazy_graph_module.py:123)
Execution Analysis: _call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1541)
Execution Analysis: ..........
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
-------- 1 -----------------
Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis: dynamo is executing a compiled program
Execution Analysis: Graph Info:
Execution Analysis: Graph Hash: 74da15968ea4ebca0372e98ddca11641
Execution Analysis: Number of Graph Inputs: 8
Execution Analysis: Number of Graph Outputs: 8
Execution Analysis: Python Frame Triggered Execution:
Execution Analysis: optimized_mod (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:427)
Execution Analysis: forward (<eval_with_key>.1:5)
Execution Analysis: _call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1541)
Execution Analysis: _wrapped_call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1532)
Execution Analysis: __call__ (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:304)
Execution Analysis: call_wrapped (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:737)
Execution Analysis: fwd (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py:53)
Execution Analysis: g (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:89)
Execution Analysis: ..........
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis: most likely user code trying to access tensor value before mark_step
Execution Analysis: Graph Info:
Execution Analysis: Graph Hash: 8e6aa44c425d7d283e8e959f91190d1a
Execution Analysis: Number of Graph Inputs: 0
Execution Analysis: Number of Graph Outputs: 1
Execution Analysis: Python Frame Triggered Execution:
Execution Analysis: optimized_mod (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:398)
Execution Analysis: forward (<eval_with_key>.3:5)
Execution Analysis: _call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1541)
Execution Analysis: _wrapped_call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1532)
Execution Analysis: __call__ (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:304)
Execution Analysis: call_wrapped (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:737)
Execution Analysis: fwd (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py:53)
Execution Analysis: g (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:89)
Execution Analysis: ..........
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis: dynamo is executing a compiled program
Execution Analysis: Graph Info:
Execution Analysis: Graph Hash: a9a891709155550a12b9907720b7d210
Execution Analysis: Number of Graph Inputs: 8
Execution Analysis: Number of Graph Outputs: 6
Execution Analysis: Python Frame Triggered Execution:
Execution Analysis: optimized_mod (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:427)
Execution Analysis: forward (<eval_with_key>.3:5)
Execution Analysis: _call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1541)
Execution Analysis: _wrapped_call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1532)
Execution Analysis: __call__ (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:304)
Execution Analysis: call_wrapped (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:737)
Execution Analysis: fwd (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py:53)
Execution Analysis: g (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:89)
Execution Analysis: ..........
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
-------- 2 -----------------
Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis: dynamo is executing a compiled program
Execution Analysis: Graph Info:
Execution Analysis: Graph Hash: 74da15968ea4ebca0372e98ddca11641
Execution Analysis: Number of Graph Inputs: 8
Execution Analysis: Number of Graph Outputs: 8
Execution Analysis: Python Frame Triggered Execution:
Execution Analysis: optimized_mod (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:427)
Execution Analysis: forward (<eval_with_key>.1:5)
Execution Analysis: _call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1541)
Execution Analysis: _wrapped_call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1532)
Execution Analysis: __call__ (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:304)
Execution Analysis: call_wrapped (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:737)
Execution Analysis: fwd (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py:53)
Execution Analysis: g (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:89)
Execution Analysis: ..........
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis: most likely user code trying to access tensor value before mark_step
Execution Analysis: Graph Info:
Execution Analysis: Graph Hash: 8e6aa44c425d7d283e8e959f91190d1a
Execution Analysis: Number of Graph Inputs: 0
Execution Analysis: Number of Graph Outputs: 1
Execution Analysis: Python Frame Triggered Execution:
Execution Analysis: optimized_mod (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:398)
Execution Analysis: forward (<eval_with_key>.3:5)
Execution Analysis: _call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1541)
Execution Analysis: _wrapped_call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1532)
Execution Analysis: __call__ (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:304)
Execution Analysis: call_wrapped (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:737)
Execution Analysis: fwd (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py:53)
Execution Analysis: g (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:89)
Execution Analysis: ..........
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis: dynamo is executing a compiled program
Execution Analysis: Graph Info:
Execution Analysis: Graph Hash: a9a891709155550a12b9907720b7d210
Execution Analysis: Number of Graph Inputs: 8
Execution Analysis: Number of Graph Outputs: 6
Execution Analysis: Python Frame Triggered Execution:
Execution Analysis: optimized_mod (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:427)
Execution Analysis: forward (<eval_with_key>.3:5)
Execution Analysis: _call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1541)
Execution Analysis: _wrapped_call_impl (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/nn/modules/module.py:1532)
Execution Analysis: __call__ (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:304)
Execution Analysis: call_wrapped (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/fx/graph_module.py:737)
Execution Analysis: fwd (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py:53)
Execution Analysis: g (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:89)
Execution Analysis: ..........
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================
Done!
WARNING:pt-xla-profiler:================================================================================
WARNING:pt-xla-profiler:Unlowered Op usage summary (more of these ops, lower performance)
WARNING:pt-xla-profiler:Note: _local_scalar_dense typically indicates CPU context access
WARNING:pt-xla-profiler:--------------------------------------------------------------------------------
WARNING:pt-xla-profiler:FRAME (count=3):
WARNING:pt-xla-profiler:Unlowered Op: "xla_cpu_fallback"
WARNING:pt-xla-profiler:
WARNING:pt-xla-profiler:
WARNING:pt-xla-profiler:FRAME (count=3):
WARNING:pt-xla-profiler: _multi_tensor_sgd (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/optim/sgd.py:360)
WARNING:pt-xla-profiler: sgd (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/optim/sgd.py:245)
WARNING:pt-xla-profiler: step (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/optim/sgd.py:80)
WARNING:pt-xla-profiler: _use_grad (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/optim/optimizer.py:76)
WARNING:pt-xla-profiler: wrapper (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/optim/optimizer.py:391)
WARNING:pt-xla-profiler: _fn (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:451)
WARNING:pt-xla-profiler: torch_dynamo_resume_in_train_loop_at_43 (/home/bbahl/test_dynamo.py:44)
WARNING:pt-xla-profiler: torch_dynamo_resume_in_train_loop_at_39 (/home/bbahl/test_dynamo.py:43)
WARNING:pt-xla-profiler: train_loop (/home/bbahl/test_dynamo.py:39)
WARNING:pt-xla-profiler: _fn (/opt/conda/envs/jaxnew/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:451)
WARNING:pt-xla-profiler: <module> (/home/bbahl/test_dynamo.py:55)
WARNING:pt-xla-profiler:
WARNING:pt-xla-profiler:
WARNING:pt-xla-profiler:================================================================================
test_dynamo.py:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import torch_xla.core.xla_model as xm

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)
train_dataloader = DataLoader(training_data, batch_size=1024)

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

device = xm.xla_device()
model = NeuralNetwork().to(device)

def train_loop(X, y, model, optimizer):
    optimizer.zero_grad()
    loss_fn = nn.CrossEntropyLoss()
    pred = model(X)
    loss = loss_fn(pred, y)
    loss.backward()
    optimizer.step()

# foreach=True opts into the multi-tensor (_multi_tensor_sgd) path traced
# in the pt-xla-profiler output above.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, foreach=True)
# Keep the optimizer step out of the compiled graph; it runs eagerly
# through the lazy-tensor path instead.
optimizer.step = torch._dynamo.disable(optimizer.step)
train_loop_ = torch.compile(train_loop, backend='openxla')

for batch, (X, y) in enumerate(train_dataloader):
    print(f"-------- {batch} -----------------")
    X = X.to(device)
    y = y.to(device)
    train_loop_(X, y, model, optimizer)
    if batch == 2:
        break
print("Done!")