My environment: Ubuntu 22.04 LTS on Windows 11 (WSL 2), with a laptop RTX 2060.
I had already confirmed that the NVIDIA driver and CUDA were up to date. I ran the command below to see the installed versions and compared them with the latest driver and CUDA versions I found online:
nvidia-smi
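Keep in mind that nvidia-smi only reports the driver (and the highest CUDA version that driver supports). It says nothing about which PyTorch build is installed, which is why this check didn't catch the problem. If you want the driver version in a script-friendly form, here is a minimal Python sketch. It assumes nvidia-smi is on PATH, and the helper names are my own:

```python
import shutil
import subprocess


def parse_gpu_query(output: str) -> list:
    """Parse `--format=csv,noheader` output into (gpu_name, driver_version) pairs."""
    rows = []
    for line in output.strip().splitlines():
        if not line.strip():
            continue  # skip blank lines
        # Split on the last comma so the name keeps any commas it might contain.
        name, _, driver = line.rpartition(",")
        rows.append((name.strip(), driver.strip()))
    return rows


def query_driver_versions() -> list:
    """Return (gpu_name, driver_version) for each GPU, or [] if nvidia-smi is missing."""
    if shutil.which("nvidia-smi") is None:
        return []
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,driver_version", "--format=csv,noheader"],
        capture_output=True,
        text=True,
        check=True,  # raise if nvidia-smi exists but fails
    )
    return parse_gpu_query(result.stdout)


if __name__ == "__main__":
    for name, driver in query_driver_versions():
        print(f"{name}: driver {driver}")
```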
I then found the actual problem.
When I used pip to upgrade all the dependencies, it installed the CPU-only build of PyTorch.
You can check which build is installed with:
python -m torch.utils.collect_env
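A quicker check, without the full report, is to ask the installed torch package directly. Here is a minimal sketch. torch.version.cuda is None for CPU-only builds, and torch.cuda.is_available() returns False when either the build or the driver can't use the GPU. The function name is hypothetical:

```python
import importlib.util


def describe_torch_build() -> dict:
    """Summarize the installed PyTorch build, or report that torch is missing."""
    if importlib.util.find_spec("torch") is None:
        return {"installed": False}
    import torch

    return {
        "installed": True,
        "version": str(torch.__version__),  # pip CUDA wheels carry a suffix such as "+cu118"
        "built_with_cuda": torch.version.cuda,  # None for CPU-only builds
        "cuda_available": torch.cuda.is_available(),
    }


if __name__ == "__main__":
    print(describe_torch_build())
```

With the broken install described here, you would expect built_with_cuda to be None and cuda_available to be False.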
In my case, collect_env reported the following. Notice the build strings starting with cpu_ for pytorch and torchvision:
[pip3] numpy==1.24.3
[pip3] pytorch-lightning==1.7.7
[pip3] torch==1.12.1
[pip3] torchaudio==0.12.1
[pip3] torchdiffeq==0.2.3
[pip3] torchmetrics==0.11.4
[pip3] torchsde==0.2.5
[pip3] torchvision==0.15.1a0+60a3e72
[conda] cudatoolkit 10.2.89 h713d32c_11 conda-forge
[conda] numpy 1.24.3 py310ha4c1d20_0 conda-forge
[conda] pytorch 1.12.1 cpu_py310h9dbd814_1
[conda] pytorch-lightning 1.7.7 pypi_0 pypi
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torchaudio 0.12.1 py310_cu102 pytorch
[conda] torchdiffeq 0.2.3 pypi_0 pypi
[conda] torchmetrics 0.11.4 pypi_0 pypi
[conda] torchsde 0.2.5 pypi_0 pypi
[conda] torchvision 0.15.1 cpu_py310h0397dac_0 conda-forge
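If the report is long, spotting the cpu_ builds by eye is easy to miss. Below is a heuristic sketch that scans the report text. It assumes the "[conda] name version build channel" and "[pip3] name==version" layouts shown above. The function name and the "cpu" prefix / "+cpu" suffix checks are my own assumptions, not an official PyTorch API:

```python
import re

# "[conda] <name> <version> <build> [channel]" lines from collect_env
CONDA_LINE = re.compile(r"^\[conda\]\s+(\S+)\s+(\S+)\s+(\S+)")
# "[pip3] <name>==<version>" lines from collect_env
PIP_LINE = re.compile(r"^\[pip3?\]\s+([^=\s]+)==(\S+)")

TORCH_PACKAGES = {"torch", "pytorch", "torchvision", "torchaudio"}


def find_cpu_builds(report: str) -> list:
    """Return entries from a collect_env report that look like CPU-only PyTorch builds.

    Heuristic: a conda build string starting with "cpu", or a pip version
    ending in the "+cpu" local tag.
    """
    suspects = []
    for raw in report.splitlines():
        line = raw.strip()
        conda = CONDA_LINE.match(line)
        if conda:
            name, version, build = conda.groups()
            if name in TORCH_PACKAGES and build.startswith("cpu"):
                suspects.append(f"conda {name} {version} {build}")
            continue
        pip = PIP_LINE.match(line)
        if pip:
            name, version = pip.groups()
            if name in TORCH_PACKAGES and version.endswith("+cpu"):
                suspects.append(f"pip {name} {version}")
    return suspects
```

Given the first report above, it flags the conda pytorch and torchvision entries. torchaudio (build py310_cu102) and cudatoolkit are not flagged.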
Then I removed the PyTorch packages:
pip uninstall torch torchvision torchaudio
and installed the CUDA build, as suggested in the official PyTorch installation instructions:
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
Running the environment check again shows that the packages are now CUDA 11.8 (cu118) builds:
[pip3] numpy==1.24.3
[pip3] pytorch-lightning==1.7.7
[pip3] torch==2.0.1+cu118
[pip3] torchaudio==2.0.2+cu118
[pip3] torchdiffeq==0.2.3
[pip3] torchmetrics==0.11.4
[pip3] torchsde==0.2.5
[pip3] torchvision==0.15.2+cu118
[conda] cudatoolkit 10.2.89 h713d32c_11 conda-forge
[conda] numpy 1.24.3 py310ha4c1d20_0 conda-forge
[conda] pytorch-lightning 1.7.7 pypi_0 pypi
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torch 2.0.1+cu118 pypi_0 pypi
[conda] torchaudio 2.0.2+cu118 pypi_0 pypi
[conda] torchdiffeq 0.2.3 pypi_0 pypi
[conda] torchmetrics 0.11.4 pypi_0 pypi
[conda] torchsde 0.2.5 pypi_0 pypi
[conda] torchvision 0.15.2+cu118 pypi_0 pypi
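To confirm that torch, torchvision and torchaudio all came from the same CUDA index, you can compare the local version tag after "+" on each pip line. Here is a small sketch, assuming the "[pip3] name==version" layout above. The function names and the default "cu118" tag are my own choices for this setup:

```python
import re

# "[pip3] <name>==<version>" lines from the collect_env report
PIP_LINE = re.compile(r"^\[pip3?\]\s+([^=\s]+)==(\S+)")


def local_tags(report, packages=("torch", "torchvision", "torchaudio")):
    """Map each listed package to its local version tag (the part after '+'), or None."""
    tags = {}
    for raw in report.splitlines():
        match = PIP_LINE.match(raw.strip())
        if match and match.group(1) in packages:
            _, _, tag = match.group(2).partition("+")
            tags[match.group(1)] = tag or None
    return tags


def all_match(report, expected="cu118"):
    """True if at least one package was found and every one found carries the expected tag."""
    tags = local_tags(report)
    return bool(tags) and all(tag == expected for tag in tags.values())
```

On the second report all three packages return cu118. On the first, torch and torchaudio have no tag and torchvision carries a git hash, so the check fails. A missing tag alone does not prove a CPU-only build, because conda packages and default PyPI wheels may not carry one.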