CUDA deps cannot be preloaded under Bazel

🐛 Describe the bug

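If Torch 2.1.0 is used as a dependency with Bazel and rules_python, import torch fails. First, _load_global_deps fails to load its library with OSError: libcufft.so.11: cannot open shared object file: No such file or directory. Its fallback, _preload_cuda_deps, then fails with OSError: libnvJitLink.so.12: cannot open shared object file: No such file or directory.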

Torch 2.1 works fine if it and its CUDA dependencies are installed into a single site-packages directory (e.g., in a virtualenv). It doesn't work under Bazel. rules_python installs each dependency into its own directory tree and appends each tree to PYTHONPATH, as the sys.path dump in the test output shows.
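In this layout, a fallback preloader has to search every sys.path entry separately. The sketch below shows roughly what that lookup involves.

This is a simplified illustration, not PyTorch's actual _preload_cuda_deps. It makes the following assumptions:

* Each nvidia-*-cu12 package places its libraries under <sys.path entry>/nvidia/<lib_folder>/lib/.
* find_cuda_lib and preload_cuda_libs are hypothetical helper names.

The loader is a parameter, so the search and ordering logic can be exercised without real CUDA libraries.

```python
import ctypes
import glob
import os
import sys
from typing import Callable, Dict, Iterable, List, Optional


def find_cuda_lib(lib_folder: str, pattern: str,
                  search_path: Iterable[str]) -> Optional[str]:
    """Hypothetical helper: return the first file matching
    <entry>/nvidia/<lib_folder>/lib/<pattern> across the search path.

    Under rules_python every wheel gets its own site-packages, so each
    CUDA library may live under a different sys.path entry.
    """
    for entry in search_path:
        matches = sorted(glob.glob(
            os.path.join(entry, "nvidia", lib_folder, "lib", pattern)))
        if matches:
            return matches[0]
    return None


def preload_cuda_libs(libs: Dict[str, str],
                      search_path: Iterable[str] = sys.path,
                      load: Callable[[str], object] = ctypes.CDLL) -> List[str]:
    """Hypothetical helper: load each library in dict order.

    Returns the paths that were loaded. Order matters: a library can only
    be loaded once everything it links against is already loaded or
    otherwise findable by the dynamic loader.
    """
    search_path = list(search_path)
    loaded: List[str] = []
    for lib_folder, pattern in libs.items():
        path = find_cuda_lib(lib_folder, pattern, search_path)
        if path is None:
            raise ValueError(
                f"{pattern} not found under any nvidia/{lib_folder}/lib")
        load(path)
        loaded.append(path)
    return loaded
```

For example, preload_cuda_libs({"nvjitlink": "libnvJitLink.so.*[0-9]", "cusparse": "libcusparse.so.*[0-9]"}) would find both libraries even if they sit in different python_deps_* trees. It loads them with ctypes.CDLL in that order.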

$ bazel test //...
Starting local Bazel server and connecting to it...
INFO: Analyzed 3 targets (66 packages loaded, 15423 targets configured).
INFO: Found 2 targets and 1 test target...
FAIL: //calculator:calc_test (see /pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/execroot/__main__/bazel-out/k8-fastbuild/testlogs/calculator/calc_test/test.log)
INFO: From Testing //calculator:calc_test:
==================== Test output for //calculator:calc_test:
sys.path=['/pay/src/torch21/calculator',
    '/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/sandbox/linux-sandbox/1/execroot/__main__/bazel-out/k8-fastbuild/bin/calculator/calc_test.runfiles',
    '/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/sandbox/linux-sandbox/1/execroot/__main__/bazel-out/k8-fastbuild/bin/calculator/calc_test.runfiles/python_deps_filelock/site-packages',
    '/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/sandbox/linux-sandbox/1/execroot/__main__/bazel-out/k8-fastbuild/bin/calculator/calc_test.runfiles/python_deps_fsspec/site-packages',
... [40 directories omitted] ...
    '/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/sandbox/linux-sandbox/1/execroot/__main__/bazel-out/k8-fastbuild/bin/calculator/calc_test.runfiles/python_deps_sympy',
    '/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/sandbox/linux-sandbox/1/execroot/__main__/bazel-out/k8-fastbuild/bin/calculator/calc_test.runfiles/python_deps_torch',
    '/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/sandbox/linux-sandbox/1/execroot/__main__/bazel-out/k8-fastbuild/bin/calculator/calc_test.runfiles/python_deps_triton',
    '/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/sandbox/linux-sandbox/1/execroot/__main__/bazel-out/k8-fastbuild/bin/calculator/calc_test.runfiles/python_deps_typing_extensions',
    '/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/execroot/__main__/external/python_3_8_x86_64-unknown-linux-gnu/lib/python38.zip',
    '/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/execroot/__main__/external/python_3_8_x86_64-unknown-linux-gnu/lib/python3.8',
    '/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/execroot/__main__/external/python_3_8_x86_64-unknown-linux-gnu/lib/python3.8/lib-dynload',
    '/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/execroot/__main__/external/python_3_8_x86_64-unknown-linux-gnu/lib/python3.8/site-packages']
Traceback (most recent call last):
  File "/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/sandbox/linux-sandbox/1/execroot/__main__/bazel-out/k8-fastbuild/bin/calculator/calc_test.runfiles/python_deps_torch/site-packages/torch/__init__.py", line 174, in _load_global_deps
    ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
  File "/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/execroot/__main__/external/python_3_8_x86_64-unknown-linux-gnu/lib/python3.8/ctypes/__init__.py", line 373, in __init__
    self._handle = _dlopen(self._name, mode)
OSError: libcufft.so.11: cannot open shared object file: No such file or directory

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/sandbox/linux-sandbox/1/execroot/__main__/bazel-out/k8-fastbuild/bin/calculator/calc_test.runfiles/__main__/calculator/calc_test.py", line 10, in <module>
    import torch  # type: ignore
  File "/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/sandbox/linux-sandbox/1/execroot/__main__/bazel-out/k8-fastbuild/bin/calculator/calc_test.runfiles/python_deps_torch/site-packages/torch/__init__.py", line 234, in <module>
    _load_global_deps()
  File "/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/sandbox/linux-sandbox/1/execroot/__main__/bazel-out/k8-fastbuild/bin/calculator/calc_test.runfiles/python_deps_torch/site-packages/torch/__init__.py", line 195, in _load_global_deps
    _preload_cuda_deps(lib_folder, lib_name)
  File "/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/sandbox/linux-sandbox/1/execroot/__main__/bazel-out/k8-fastbuild/bin/calculator/calc_test.runfiles/python_deps_torch/site-packages/torch/__init__.py", line 161, in _preload_cuda_deps
    ctypes.CDLL(lib_path)
  File "/pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/execroot/__main__/external/python_3_8_x86_64-unknown-linux-gnu/lib/python3.8/ctypes/__init__.py", line 373, in __init__
    self._handle = _dlopen(self._name, mode)
OSError: libnvJitLink.so.12: cannot open shared object file: No such file or directory
================================================================================
INFO: Elapsed time: 131.313s, Critical Path: 6.78s
INFO: 5 processes: 3 internal, 2 linux-sandbox.
INFO: Build completed, 1 test FAILED, 5 total actions
//calculator:calc_test                                                   FAILED in 0.9s
  /pay/home/georgevreilly/.cache/bazel/_bazel_georgevreilly/b060158845e808ff2a9c2fcf0dcfee37/execroot/__main__/bazel-out/k8-fastbuild/testlogs/calculator/calc_test/test.log

Executed 1 out of 1 test: 1 fails locally.
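To find out which library actually asks for libnvJitLink.so.12, one option is to list the DT_NEEDED entries of each CUDA library. readelf -d (from GNU binutils) prints them, and the output is easy to parse.

Below is a minimal sketch. It assumes readelf is on PATH; needed_libs and parse_needed are hypothetical helper names.

```python
import re
import subprocess
from typing import List

# readelf -d prints one line per dynamic-section entry, e.g.
#  0x0000000000000001 (NEEDED)             Shared library: [libnvJitLink.so.12]
_NEEDED_RE = re.compile(r"\(NEEDED\)\s+Shared library: \[([^\]]+)\]")


def parse_needed(readelf_output: str) -> List[str]:
    """Hypothetical helper: extract sonames from the NEEDED lines of `readelf -d` output."""
    return _NEEDED_RE.findall(readelf_output)


def needed_libs(path: str) -> List[str]:
    """Hypothetical helper: run `readelf -d` on a shared object and return its DT_NEEDED sonames."""
    result = subprocess.run(["readelf", "-d", path],
                            capture_output=True, text=True, check=True)
    return parse_needed(result.stdout)
```

Running needed_libs on each .so under the nvidia/*/lib directories gives the dependency edges needed to choose a safe preload order.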

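This can be fixed by reordering cuda_libs in _load_global_deps so that the libraries are topologically sorted, i.e., each library is preloaded only after all of the libraries it depends on. Once a shared library has been loaded, the dynamic loader reuses it to satisfy later dependencies on the same soname. With the right order, the fallback therefore works even when the libraries live in separate directory trees. The patch also adds nvjitlink, which was missing from the list.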

diff --git a/torch/__init__.py b/torch/__init__.py
index 98c9a43511c..bad6a5f6c3d 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -178,7 +178,11 @@ def _load_global_deps() -> None:
     except OSError as err:
         # Can only happen for wheel with cuda libs as PYPI deps
         # As PyTorch is not purelib, but nvidia-*-cu12 is
+        # These dependencies have been topologically sorted,
+        # so that a lib is loaded after all of its dependencies.
         cuda_libs: Dict[str, str] = {
+            'nvjitlink': 'libnvJitLink.so.*[0-9]',
+            'cusparse': 'libcusparse.so.*[0-9]',
             'cublas': 'libcublas.so.*[0-9]',
             'cudnn': 'libcudnn.so.*[0-9]',
             'cuda_nvrtc': 'libnvrtc.so.*[0-9]',
@@ -187,7 +191,6 @@ def _load_global_deps() -> None:
             'cufft': 'libcufft.so.*[0-9]',
             'curand': 'libcurand.so.*[0-9]',
             'cusolver': 'libcusolver.so.*[0-9]',
-            'cusparse': 'libcusparse.so.*[0-9]',
             'nccl': 'libnccl.so.*[0-9]',
             'nvtx': 'libnvToolsExt.so.*[0-9]',
         }
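The ordering rule in the diff comment can be made mechanical. Given a map from each library to the libraries it links against, a depth-first topological sort yields a load order in which every library follows its dependencies.

The dependency edges in ASSUMED_DEPS are assumptions for illustration only; verify them with readelf -d against the actual wheels. load_order and is_load_order are hypothetical helpers, and only library keys visible in the diff are used. Plain DFS is used rather than graphlib, because the environment here is Python 3.8 and graphlib was added in 3.9.

```python
from typing import Dict, List, Sequence


def load_order(deps: Dict[str, Sequence[str]]) -> List[str]:
    """Hypothetical helper: order libraries so each follows all of its dependencies.

    Depth-first post-order traversal; raises ValueError on a cycle.
    """
    order: List[str] = []
    state: Dict[str, str] = {}  # lib -> "visiting" or "done"

    def visit(lib: str) -> None:
        if state.get(lib) == "done":
            return
        if state.get(lib) == "visiting":
            raise ValueError(f"dependency cycle involving {lib}")
        state[lib] = "visiting"
        for dep in deps.get(lib, ()):
            visit(dep)
        state[lib] = "done"
        order.append(lib)

    for lib in deps:
        visit(lib)
    return order


def is_load_order(order: Sequence[str], deps: Dict[str, Sequence[str]]) -> bool:
    """Hypothetical helper: True if every library appears after all of its
    dependencies (a dependency missing from `order` counts as a violation)."""
    position = {lib: i for i, lib in enumerate(order)}
    return all(dep in position and position[dep] < position[lib]
               for lib in order for dep in deps.get(lib, ()))


# ASSUMED edges, for illustration only; derive the real ones from the wheels.
ASSUMED_DEPS: Dict[str, Sequence[str]] = {
    "nvjitlink": [],
    "cusparse": ["nvjitlink"],
    "cublas": [],
    "cusolver": ["cublas", "cusparse", "nvjitlink"],
}
```

The old key order visible in the diff puts cusolver before cusparse and omits nvjitlink. Under these assumed edges, is_load_order returns False for the old order and True for the patched one.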

I have a full repro of the problem: a tiny Python app that works in a regular virtualenv but fails under Bazel. The repro also includes a tool that patches the Torch wheel, and the patched wheel works for us.

Related Issues

Versions

$ python collect_env.py
/pay/tmp/venv-torch21/lib/python3.8/site-packages/torch/nn/modules/transformer.py:20: UserWarning: Failed to initialize NumPy: numpy.core.multiarray failed to import (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:84.)
  device: torch.device = torch.device(torch._C._get_default_device()),  # torch.device('cpu'),
Collecting environment information...
/pay/tmp/venv-torch21/lib/python3.8/site-packages/torch/cuda/__init__.py:138: UserWarning: CUDA initialization: The NVIDIA driver on your system is too old (found version 11060). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver. (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:108.)
  return torch._C._cuda_getDeviceCount() > 0

cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.0.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      46 bits physical, 48 bits virtual
CPU(s):                             8
On-line CPU(s) list:                0-7
Thread(s) per core:                 2
Core(s) per socket:                 4
Socket(s):                          1
NUMA node(s):                       1
Vendor ID:                          GenuineIntel
CPU family:                         6
Model:                              79
Model name:                         Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz
Stepping:                           1
CPU MHz:                            3000.000
CPU max MHz:                        3000.0000
CPU min MHz:                        1200.0000
BogoMIPS:                           4600.02
Hypervisor vendor:                  Xen
Virtualization type:                full
L1d cache:                          128 KiB
L1i cache:                          128 KiB
L2 cache:                           1 MiB
L3 cache:                           45 MiB
NUMA node0 CPU(s):                  0-7
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        KVM: Mitigation: VMX unsupported
Vulnerability L1tf:                 Mitigation; PTE Inversion
Vulnerability Mds:                  Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown:             Mitigation; PTI
Vulnerability Mmio stale data:      Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Vulnerable
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single pti fsgsbase bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx xsaveopt

Versions of relevant libraries:
[pip3] torch==2.1.0
[pip3] triton==2.1.0
[conda] Could not collect