Last active
August 4, 2023 14:10
-
-
Save cmdr2/92178748f1cb77ea01d698a3008bc5fa to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import os | |
import platform | |
from importlib.metadata import version as pkg_version | |
from sdkit.utils import log | |
from easydiffusion import app | |
# future home of scripts/check_modules.py | |
def _wheel_dist_name(filename):
    """Return the normalized distribution name encoded in a package filename.

    Wheels are named ``<dist>-<version>-...whl`` (any ``-`` inside <dist> is
    already escaped to ``_``); sdists are ``<dist>-<version>.tar.gz`` and <dist>
    may still contain ``-``. The result is lower-cased with ``-`` mapped to ``_``.
    """
    if filename.endswith(".whl"):
        name = filename.split("-", 1)[0]
    else:
        name = filename[: -len(".tar.gz")].rsplit("-", 1)[0]
    return name.replace("-", "_").lower()


def _is_package_file(dist_name, package):
    """Return True if *dist_name* is *package* itself or a CUDA variant of it (e.g. ``nvidia_cudnn_cu12``)."""
    if dist_name == package:
        return True
    variant_prefix = package + "_cu"
    return dist_name.startswith(variant_prefix) and dist_name[len(variant_prefix):].isdigit()


def get_trt_windows_install_commands(trt_dir=None):
    """Return the pip install arguments needed for TensorRT on Windows.

    If *trt_dir* (default: ``<app.ROOT_DIR>/tensorrt``) is a non-empty
    directory, use the local ``.whl``/``.tar.gz`` files found there for
    nvidia-cudnn, tensorrt-libs and tensorrt (in that order, skipping any that
    are missing, first match in sorted filename order). Each file path is
    double-quoted. The supporting protobuf/polygraphy/onnx requirement from
    the NVIDIA index comes last. Otherwise, choose the NGC packages that match
    the CUDA version of the installed torch build.

    Each returned string is meant to be appended to
    ``python -m pip install --upgrade`` and run through the shell.

    Raises:
        RuntimeError: if torch was not built against CUDA 11.x or 12.x.
    """
    ngc = "--extra-index-url=https://pypi.ngc.nvidia.com --trusted-host pypi.ngc.nvidia.com"
    extra_deps = f"protobuf==3.20.3 polygraphy==0.47.1 onnx==1.14.0 {ngc}"

    if trt_dir is None:
        trt_dir = os.path.join(app.ROOT_DIR, "tensorrt")
    files = sorted(os.listdir(trt_dir)) if os.path.isdir(trt_dir) else []

    if files:
        commands = []
        for package in ("nvidia_cudnn", "tensorrt_libs", "tensorrt"):
            # Compare whole distribution names: a plain prefix test would let
            # "tensorrt" pick up tensorrt_libs / tensorrt_lean / tensorrt_dispatch files.
            match = next(
                (
                    f
                    for f in files
                    if f.endswith((".whl", ".tar.gz")) and _is_package_file(_wheel_dist_name(f), package)
                ),
                None,
            )
            if match:
                # The caller builds a shell command by string concatenation, so
                # quote the path to keep it as one argument if it contains spaces.
                commands.append(f'"{os.path.join(trt_dir, match)}"')
        commands.append(extra_deps)
        return commands

    import torch

    cuda_version = torch.version.cuda or ""  # None on CPU-only torch builds
    if cuda_version.startswith("11."):
        return [
            f"nvidia-cudnn-cu11 --pre {ngc}",
            f"tensorrt-libs --pre {ngc}",
            f"tensorrt==9.0.0.post11.dev1 --pre {ngc}",
            extra_deps,
        ]
    if cuda_version.startswith("12."):
        return [
            f"nvidia-cudnn-cu12 --pre {ngc}",
            f"tensorrt-libs --pre {ngc}",
            f"tensorrt --pre {ngc}",
            extra_deps,
        ]
    raise RuntimeError(
        f"TensorRT needs a PyTorch build with CUDA 11.x or 12.x, but found CUDA version: {torch.version.cuda}"
    )
# Optional packages that can be installed/uninstalled at runtime, keyed by
# package name, then by action ("install"/"uninstall"), then by OS name.
# An "install" entry is either a list of pip install arguments or a callable
# that returns one; an "uninstall" entry lists names for `pip uninstall -y`.
manifest = dict(
    tensorrt=dict(
        install=dict(
            windows=get_trt_windows_install_commands,
            linux=[
                f"{requirement} --extra-index-url=https://pypi.ngc.nvidia.com --trusted-host pypi.ngc.nvidia.com"
                for requirement in (
                    "nvidia-cudnn",
                    "tensorrt-libs",
                    "tensorrt",
                    "protobuf==3.20.3 polygraphy==0.47.1 onnx==1.14.0",
                )
            ],
        ),
        uninstall=dict(
            windows=["tensorrt"],
            linux=["tensorrt"],
        ),
        # TODO also uninstall tensorrt-libs and nvidia-cudnn, but do it upon restarting (avoid 'file in use' error)
    ),
)

# Names of packages whose install() is currently running; a repeated install()
# request for one of these returns early instead of starting another install.
installing = list()
def get_installed_packages() -> dict:
    """Return ``{package_name: installed_version}`` for each installed package in ``manifest``.

    Packages listed in ``manifest`` that are not installed are omitted.
    """
    # Look each version up once; calling is_installed() first would query it twice.
    versions = {module_name: version(module_name) for module_name in manifest}
    return {module_name: ver for module_name, ver in versions.items() if ver is not None}
def is_installed(module_name) -> bool:
    """Tell whether a distribution named *module_name* is present in this environment."""
    found_version = version(module_name)
    return found_version is not None
def install(module_name):
    """Install an optional package listed in ``manifest`` by running pip.

    Logs and returns early if the package is already installed, or if an
    install of it is already in progress (tracked in ``installing``).

    Raises:
        RuntimeError: if the package is not in ``manifest``, has no install
            commands for the current OS, or any pip command exits non-zero.
    """
    os_name = platform.system().lower()

    # Validate the name first, so an unknown package is always reported as such
    # even when a distribution with that name happens to be installed.
    if module_name not in manifest:
        raise RuntimeError(f"Can't install unknown package: {module_name}!")
    if is_installed(module_name):
        log.info(f"{module_name} has already been installed!")
        return
    if module_name in installing:
        log.info(f"{module_name} is already installing!")
        return

    # An OS missing from the manifest (e.g. "darwin") used to surface as a bare KeyError.
    commands = manifest[module_name]["install"].get(os_name)
    if commands is None:
        raise RuntimeError(f"Can't install {module_name} on this OS: {os_name}!")
    if callable(commands):
        commands = commands()
    commands = [f"python -m pip install --upgrade {cmd}" for cmd in commands]

    installing.append(module_name)
    try:
        for cmd in commands:
            print(">", cmd)
            if os.system(cmd) != 0:
                raise RuntimeError(f"Error while running {cmd}. Please check the logs in the command-line.")
    finally:
        # Always clear the in-progress marker, even when a pip command fails.
        installing.remove(module_name)
def uninstall(module_name):
    """Uninstall an optional package listed in ``manifest`` by running pip.

    Logs and returns early if the package isn't installed.

    Raises:
        RuntimeError: if the package is not in ``manifest``, has no uninstall
            commands for the current OS, or any pip command exits non-zero.
    """
    os_name = platform.system().lower()

    # Validate the name first, consistently with install().
    if module_name not in manifest:
        raise RuntimeError(f"Can't uninstall unknown package: {module_name}!")
    if not is_installed(module_name):
        log.info(f"{module_name} hasn't been installed!")
        return

    # An OS missing from the manifest (e.g. "darwin") used to surface as a bare KeyError.
    commands = manifest[module_name]["uninstall"].get(os_name)
    if commands is None:
        raise RuntimeError(f"Can't uninstall {module_name} on this OS: {os_name}!")
    if callable(commands):
        commands = commands()
    commands = [f"python -m pip uninstall -y {cmd}" for cmd in commands]

    for cmd in commands:
        print(">", cmd)
        if os.system(cmd) != 0:
            raise RuntimeError(f"Error while running {cmd}. Please check the logs in the command-line.")
def version(module_name: str) -> "str | None":
    """Return the installed version string of distribution *module_name*, or None if it isn't installed."""
    from importlib.metadata import PackageNotFoundError

    try:
        return pkg_version(module_name)
    except (PackageNotFoundError, ValueError):
        # PackageNotFoundError: not installed. ValueError: newer Pythons reject
        # an empty name. Anything else (including KeyboardInterrupt) propagates.
        return None
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment