Gist by @nilreml, last active March 24, 2024
# Monkeypatch pytorch to run kernel autotuning on non-enterprise CUDA GPUs
# until https://github.com/pytorch/pytorch/issues/109489 is resolved:
import sys


def uncache(exclude):
    """Remove package modules from cache except excluded ones.

    On next import they will be reloaded.

    Args:
        exclude (iter<str>): Sequence of module paths.
    """
    pkgs = []
    for mod in exclude:
        pkg = mod.split(".", 1)[0]
        pkgs.append(pkg)

    to_uncache = []
    for mod in sys.modules:
        if mod in exclude:
            continue
        if mod in pkgs:
            to_uncache.append(mod)
            continue
        for pkg in pkgs:
            if mod.startswith(pkg + "."):
                to_uncache.append(mod)
                break

    for mod in to_uncache:
        del sys.modules[mod]


# These lines must appear before 'import torch' in the main training/inference script:
from torch._inductor import utils

utils._use_template_for_cuda = lambda x, y: True
uncache([])
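To illustrate what uncache does without involving torch, here is a minimal self-contained sketch using the stdlib json package (the uncache definition is repeated verbatim-in-spirit so the snippet runs standalone; the choice of json is purely illustrative):

```python
import sys


def uncache(exclude):
    """Remove package modules from sys.modules except excluded ones."""
    pkgs = [mod.split(".", 1)[0] for mod in exclude]
    to_uncache = []
    for mod in sys.modules:
        if mod in exclude:
            continue
        # Evict the top-level package itself and any of its submodules.
        if mod in pkgs or any(mod.startswith(p + ".") for p in pkgs):
            to_uncache.append(mod)
    for mod in to_uncache:
        del sys.modules[mod]


import json  # noqa: E402 -- populates sys.modules with 'json' and its submodules

# Evict the whole 'json' package from the module cache,
# keeping only 'json.decoder' loaded:
uncache(["json.decoder"])

print("json" in sys.modules)          # evicted: a later 'import json' reloads it
print("json.decoder" in sys.modules)  # excluded: still cached
```

In the gist above the same mechanism lets a later `import torch` in the main script pick up freshly reloaded modules while the already-patched ones listed in `exclude` stay in place.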