Skip to content

Instantly share code, notes, and snippets.

@cmdr2
Last active August 4, 2023 14:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cmdr2/92178748f1cb77ea01d698a3008bc5fa to your computer and use it in GitHub Desktop.
Save cmdr2/92178748f1cb77ea01d698a3008bc5fa to your computer and use it in GitHub Desktop.
import sys
import os
import platform
from importlib.metadata import version as pkg_version
from sdkit.utils import log
from easydiffusion import app
# future home of scripts/check_modules.py
def _find_local_package(files, prefix, exclude=()):
    """Return the archive in ``files`` that best matches distribution ``prefix``, or None.

    Only ``.whl``/``.tar.gz`` files qualify, and only if their distribution name
    (the part of the filename before the first ``-``) starts with ``prefix``.
    This lets ``nvidia_cudnn`` match ``nvidia_cudnn_cu11-...``.
    The shortest distribution name wins, so ``tensorrt`` selects ``tensorrt-*``
    rather than ``tensorrt_libs-*``. Files in ``exclude`` are skipped so the
    same archive is never returned for two different prefixes.
    """
    candidates = [
        f
        for f in files
        if f not in exclude and f.endswith((".whl", ".tar.gz")) and f.split("-", 1)[0].startswith(prefix)
    ]
    if not candidates:
        return None
    # Tie-break on the filename itself so the choice does not depend on listdir order.
    return min(candidates, key=lambda f: (len(f.split("-", 1)[0]), f))


def get_trt_windows_install_commands():
    """Return the pip argument strings used to install TensorRT on Windows.

    If ``<app.ROOT_DIR>/tensorrt`` is a non-empty directory, the local cuDNN,
    TensorRT-libs and TensorRT archives found there are used, in that order,
    followed by the pinned helper packages. Otherwise the packages are fetched
    from NVIDIA's package index, and the cuDNN flavour is chosen to match
    ``torch.version.cuda``.

    The caller prefixes each returned string with
    ``python -m pip install --upgrade``, so a single entry may name several
    packages and carry pip options.

    Raises:
        RuntimeError: if no local archives are present and torch does not report
            a CUDA 11.x or 12.x version (including CPU-only torch builds).
    """
    ngc_index = "--extra-index-url=https://pypi.ngc.nvidia.com --trusted-host pypi.ngc.nvidia.com"
    helper_packages = f"protobuf==3.20.3 polygraphy==0.47.1 onnx==1.14.0 {ngc_index}"

    trt_dir = os.path.join(app.ROOT_DIR, "tensorrt")
    files = sorted(os.listdir(trt_dir)) if os.path.isdir(trt_dir) else []
    if files:
        wheels = []
        used = set()
        # Wheel/sdist filenames use underscores in the distribution name.
        for prefix in ("nvidia_cudnn", "tensorrt_libs", "tensorrt"):
            archive = _find_local_package(files, prefix, exclude=used)
            if archive:
                used.add(archive)
                wheels.append(os.path.join(trt_dir, archive))
        wheels.append(helper_packages)
        return wheels

    import torch

    # torch.version.cuda is None on CPU-only builds.
    cuda_version = torch.version.cuda or ""
    if cuda_version.startswith("11."):
        cudnn_package, tensorrt_package = "nvidia-cudnn-cu11", "tensorrt==9.0.0.post11.dev1"
    elif cuda_version.startswith("12."):
        cudnn_package, tensorrt_package = "nvidia-cudnn-cu12", "tensorrt"
    else:
        raise RuntimeError(
            f"No TensorRT install commands for torch CUDA version {torch.version.cuda!r} (expected 11.x or 12.x)."
        )

    return [
        f"{cudnn_package} --pre {ngc_index}",
        f"tensorrt-libs --pre {ngc_index}",
        f"{tensorrt_package} --pre {ngc_index}",
        helper_packages,
    ]
# Optional packages that can be installed or uninstalled at runtime, keyed by the
# name passed to install()/uninstall()/version().
# For each action ("install" / "uninstall") and OS name, the value is either a
# list of pip argument strings or a callable that returns such a list.
# OS names are the values returned by platform.system().lower().
# install()/uninstall() call the callable lazily, at the moment the action runs.
manifest = {
    "tensorrt": {
        "install": {
            # Computed at install time by a function rather than listed statically.
            "windows": get_trt_windows_install_commands,
            "linux": [
                "nvidia-cudnn --extra-index-url=https://pypi.ngc.nvidia.com --trusted-host pypi.ngc.nvidia.com",
                "tensorrt-libs --extra-index-url=https://pypi.ngc.nvidia.com --trusted-host pypi.ngc.nvidia.com",
                "tensorrt --extra-index-url=https://pypi.ngc.nvidia.com --trusted-host pypi.ngc.nvidia.com",
                "protobuf==3.20.3 polygraphy==0.47.1 onnx==1.14.0 --extra-index-url=https://pypi.ngc.nvidia.com --trusted-host pypi.ngc.nvidia.com",
            ],
        },
        "uninstall": {
            "windows": ["tensorrt"],
            "linux": ["tensorrt"],
        },
        # TODO also uninstall tensorrt-libs and nvidia-cudnn, but do it upon restarting (avoid 'file in use' error)
    },
}

# Names of packages whose install() is currently in progress.
# install() consults this list to ignore a repeated request for the same package.
installing = []
def get_installed_packages() -> dict:
    """Return ``{name: installed_version}`` for every ``manifest`` package that is installed.

    Packages whose version cannot be looked up (see ``version``) are omitted.
    The original ``-> list`` annotation was wrong, because this has always
    returned a dict.
    """
    installed = {}
    for module_name in manifest:
        # One metadata lookup per package (is_installed() + version() did two).
        module_version = version(module_name)
        if module_version is not None:
            installed[module_name] = module_version
    return installed
def is_installed(module_name) -> bool:
    """Return True when an installed version of ``module_name`` can be found.

    ``version`` yields None whenever the lookup fails. Therefore any non-None
    result counts as "installed".
    """
    installed_version = version(module_name)
    return installed_version is not None
def install(module_name):
    """Install ``module_name`` (a key of ``manifest``) using pip, for the current OS.

    Logs and returns without doing anything in two cases: the package is
    already installed, or an install of it is already in progress.
    The manifest entry for the current OS is either a list of pip argument
    strings or a callable returning one. Each entry is run as
    ``python -m pip install --upgrade <args>``. While the commands run,
    ``module_name`` is listed in ``installing``.

    Raises:
        RuntimeError: if ``module_name`` is not in ``manifest``, if the manifest
            has no install commands for this OS, or if any pip command exits
            with a non-zero status.
    """
    if is_installed(module_name):
        log.info(f"{module_name} has already been installed!")
        return
    if module_name in installing:
        log.info(f"{module_name} is already installing!")
        return
    if module_name not in manifest:
        raise RuntimeError(f"Can't install unknown package: {module_name}!")

    os_name = platform.system().lower()
    # .get() so an OS missing from the manifest (e.g. "darwin") gives a clear error, not a KeyError.
    commands = manifest[module_name]["install"].get(os_name)
    if callable(commands):
        # Some entries are computed at install time from the runtime environment.
        commands = commands()
    if commands is None:
        raise RuntimeError(f"Can't install {module_name} on this platform ({os_name})!")

    installing.append(module_name)
    try:
        for args in commands:
            cmd = f"python -m pip install --upgrade {args}"
            print(">", cmd)
            if os.system(cmd) != 0:
                raise RuntimeError(f"Error while running {cmd}. Please check the logs in the command-line.")
    finally:
        # Always clear the in-progress marker, even if a pip command failed.
        installing.remove(module_name)
def uninstall(module_name):
    """Uninstall ``module_name`` (a key of ``manifest``) using pip, for the current OS.

    Logs and returns without doing anything if the package is not installed.
    The manifest entry for the current OS is either a list of pip package
    names or a callable returning one. Each entry is run as
    ``python -m pip uninstall -y <args>``.

    Raises:
        RuntimeError: if ``module_name`` is not in ``manifest``, if the manifest
            has no uninstall commands for this OS, or if any pip command exits
            with a non-zero status.
    """
    if not is_installed(module_name):
        log.info(f"{module_name} hasn't been installed!")
        return
    if module_name not in manifest:
        raise RuntimeError(f"Can't uninstall unknown package: {module_name}!")

    os_name = platform.system().lower()
    # .get() so an OS missing from the manifest (e.g. "darwin") gives a clear error, not a KeyError.
    commands = manifest[module_name]["uninstall"].get(os_name)
    if callable(commands):
        commands = commands()
    if commands is None:
        raise RuntimeError(f"Can't uninstall {module_name} on this platform ({os_name})!")

    for args in commands:
        cmd = f"python -m pip uninstall -y {args}"
        print(">", cmd)
        if os.system(cmd) != 0:
            raise RuntimeError(f"Error while running {cmd}. Please check the logs in the command-line.")
def version(module_name: str) -> "str | None":
    """Return the installed version string of distribution ``module_name``, or None.

    None means that no installed distribution with that name was found, or
    that the name was rejected as invalid. Unlike the previous bare
    ``except:``, this does not swallow KeyboardInterrupt, SystemExit, or
    unrelated errors.
    """
    from importlib.metadata import PackageNotFoundError

    try:
        return pkg_version(module_name)
    except (PackageNotFoundError, ValueError):
        # ValueError: some Python versions reject an empty distribution name this way.
        return None
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment