Last active
March 24, 2024 15:00
-
-
Save nilreml/c88621af7295c0d9af8ae9599da49c15 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Monkeypatch pytorch to run kernel autotuning on non-enterprise CUDA GPUs
# until https://github.com/pytorch/pytorch/issues/109489 is resolved:
import sys
def uncache(exclude):
    """Evict cached package modules from ``sys.modules`` except excluded ones.

    For each module path in *exclude*, the whole top-level package it
    belongs to is removed from the import cache — except the excluded
    paths themselves — so the next ``import`` reloads those modules
    from source (picking up any monkeypatches applied beforehand).

    Args:
        exclude (iterable of str): Dotted module paths to keep cached.
            An empty iterable makes this a no-op.
    """
    exclude = set(exclude)
    # Top-level packages whose cached submodules should be evicted.
    pkgs = {mod.split(".", 1)[0] for mod in exclude}
    # startswith() accepts a tuple of prefixes; an empty tuple matches nothing.
    prefixes = tuple(pkg + "." for pkg in pkgs)
    to_uncache = [
        mod
        for mod in sys.modules
        if mod not in exclude and (mod in pkgs or mod.startswith(prefixes))
    ]
    # Delete after collecting: mutating sys.modules while iterating it
    # would raise RuntimeError.
    for mod in to_uncache:
        del sys.modules[mod]
# These lines must appear before 'import torch' in the main
# training/inference script: they patch the inductor predicate that
# restricts template-based kernel autotuning to enterprise CUDA GPUs.
from torch._inductor import utils


def _always_use_template(choice, layout):
    """Replacement predicate: permit CUDA autotuning templates on any GPU."""
    return True


utils._use_template_for_cuda = _always_use_template

# Flush the module cache so subsequent imports see the patched inductor.
# NOTE(review): with the current uncache() implementation an empty
# exclude list evicts nothing — confirm the intended argument.
uncache([])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment