@sandeshrajbhandari
Created March 18, 2024 03:27
Open-Sora inference notebook in Colab T4 - not running yet.
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"collapsed_sections": [
"T9aKqAQYjTAV"
],
"gpuType": "T4"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "code",
"source": [
"!python -V"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0GacQ__2XKxf",
"outputId": "81bef578-326d-4b95-a930-0bac028cbe6f"
},
"execution_count": 28,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Python 3.10.12\n"
]
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sE8MMGbDMHU6"
},
"outputs": [],
"source": [
"%%shell\n",
"# create a virtual env\n",
"\n",
"# install torch\n",
"# the command below is for CUDA 12.1, choose install commands from\n",
"# https://pytorch.org/get-started/locally/ based on your own CUDA version\n",
"pip install torch torchvision\n",
"\n",
"# install flash attention (optional)\n",
"pip install packaging ninja\n",
"pip install flash-attn --no-build-isolation\n",
"\n",
"# install apex (optional)\n",
"# pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings \"--build-option=--cpp_ext\" --config-settings \"--build-option=--cuda_ext\" git+https://github.com/NVIDIA/apex.git\n",
"## install later after cloning repo and editing runtime error.\n",
"\n",
"# install xformers\n",
"pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu121\n",
"\n",
"# install this project\n",
"git clone https://github.com/hpcaitech/Open-Sora\n",
"cd Open-Sora\n",
"pip install -v ."
]
},
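{
"cell_type": "code",
"source": [
"# optional sanity check for the installs above (a minimal sketch; flash-attn\n",
"# and apex are optional extras, so their absence is not fatal).\n",
"import torch\n",
"print(torch.__version__, torch.version.cuda, torch.cuda.is_available())\n",
"try:\n",
"    import flash_attn\n",
"    print(\"flash_attn\", flash_attn.__version__)\n",
"except ImportError:\n",
"    print(\"flash_attn not installed (optional)\")"
],
"metadata": {},
"execution_count": null,
"outputs": []
},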
{
"cell_type": "markdown",
"source": [
"## install apex, troubleshoot cuda compile error"
],
"metadata": {
"id": "Ee2hzjNyZPJu"
}
},
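{
"cell_type": "code",
"source": [
"# a minimal sketch of the check that trips the stock apex build: compare\n",
"# nvcc's bare-metal cuda release against the cuda version the installed torch\n",
"# wheel was built with. the patched setup.py below tolerates a mismatch here\n",
"# instead of raising.\n",
"import subprocess\n",
"import torch\n",
"words = subprocess.check_output([\"nvcc\", \"-V\"], universal_newlines=True).split()\n",
"bare_metal = words[words.index(\"release\") + 1].rstrip(\",\")\n",
"print(\"nvcc cuda release:\", bare_metal)\n",
"print(\"torch built with cuda:\", torch.version.cuda)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},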
{
"cell_type": "code",
"source": [
"%cd ..\n",
"!git clone https://github.com/NVIDIA/apex"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Vx_h2skfYvNk",
"outputId": "6a91b5f2-8055-4f25-8310-0ef7cf2d4f47"
},
"execution_count": 35,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/content\n",
"Cloning into 'apex'...\n",
"remote: Enumerating objects: 11638, done.\u001b[K\n",
"remote: Counting objects: 100% (3706/3706), done.\u001b[K\n",
"remote: Compressing objects: 100% (569/569), done.\u001b[K\n",
"remote: Total 11638 (delta 3342), reused 3264 (delta 3134), pack-reused 7932\u001b[K\n",
"Receiving objects: 100% (11638/11638), 15.47 MiB | 18.72 MiB/s, done.\n",
"Resolving deltas: 100% (8171/8171), done.\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"### setup.py code as setup_py variable string"
],
"metadata": {
"id": "T9aKqAQYjTAV"
}
},
{
"cell_type": "code",
"source": [
"setup_py = \"\"\"\n",
"import sys\n",
"import warnings\n",
"import os\n",
"import glob\n",
"from packaging.version import parse, Version\n",
"\n",
"from setuptools import setup, find_packages\n",
"import subprocess\n",
"\n",
"import torch\n",
"from torch.utils.cpp_extension import (\n",
" BuildExtension,\n",
" CppExtension,\n",
" CUDAExtension,\n",
" CUDA_HOME,\n",
" load,\n",
")\n",
"\n",
"# ninja build does not work unless include_dirs are abs path\n",
"this_dir = os.path.dirname(os.path.abspath(__file__))\n",
"\n",
"\n",
"def get_cuda_bare_metal_version(cuda_dir):\n",
" raw_output = subprocess.check_output([cuda_dir + \"/bin/nvcc\", \"-V\"], universal_newlines=True)\n",
" output = raw_output.split()\n",
" release_idx = output.index(\"release\") + 1\n",
" bare_metal_version = parse(output[release_idx].split(\",\")[0])\n",
"\n",
" return raw_output, bare_metal_version\n",
"\n",
"\n",
"def check_cuda_torch_binary_vs_bare_metal(cuda_dir):\n",
" raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)\n",
" torch_binary_version = parse(torch.version.cuda)\n",
"\n",
" print(\"\\nCompiling cuda extensions with\")\n",
" print(raw_output + \"from \" + cuda_dir + \"/bin\\n\")\n",
"\n",
" if (bare_metal_version != torch_binary_version):\n",
" print(\n",
" \"Cuda extensions are being compiled with a version of Cuda that does \"\n",
" \"not match the version used to compile Pytorch binaries. \"\n",
" \"Pytorch binaries were compiled with Cuda {}.\\n\".format(torch.version.cuda)\n",
" + \"In some cases, a minor-version mismatch will not cause later errors: \"\n",
" \"https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. \"\n",
" \"You can try commenting out this check (at your own risk).\"\n",
" )\n",
"\n",
"\n",
"def raise_if_cuda_home_none(global_option: str) -> None:\n",
" if CUDA_HOME is not None:\n",
" return\n",
" raise RuntimeError(\n",
" f\"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? \"\n",
" \"If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, \"\n",
" \"only images whose names contain 'devel' will provide nvcc.\"\n",
" )\n",
"\n",
"\n",
"def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool:\n",
" cudnn_available = torch.backends.cudnn.is_available()\n",
" cudnn_version = torch.backends.cudnn.version() if cudnn_available else None\n",
" if not (cudnn_available and (cudnn_version >= required_cudnn_version)):\n",
" warnings.warn(\n",
" f\"Skip `{global_option}` as it requires cuDNN {required_cudnn_version} or later, \"\n",
" f\"but {'cuDNN is not available' if not cudnn_available else cudnn_version}\"\n",
" )\n",
" return False\n",
" return True\n",
"\n",
"\n",
"if not torch.cuda.is_available():\n",
" # https://github.com/NVIDIA/apex/issues/486\n",
" # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),\n",
" # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command).\n",
" print(\n",
" \"\\nWarning: Torch did not find available GPUs on this system.\\n\",\n",
" \"If your intention is to cross-compile, this is not an error.\\n\"\n",
" \"By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\\n\"\n",
" \"Volta (compute capability 7.0), Turing (compute capability 7.5),\\n\"\n",
" \"and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\\n\"\n",
" \"If you wish to cross-compile for a single specific architecture,\\n\"\n",
" 'export TORCH_CUDA_ARCH_LIST=\"compute capability\" before running setup.py.\\n',\n",
" )\n",
" if os.environ.get(\"TORCH_CUDA_ARCH_LIST\", None) is None and CUDA_HOME is not None:\n",
" _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)\n",
" if bare_metal_version >= Version(\"11.8\"):\n",
" os.environ[\"TORCH_CUDA_ARCH_LIST\"] = \"6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0\"\n",
" elif bare_metal_version >= Version(\"11.1\"):\n",
" os.environ[\"TORCH_CUDA_ARCH_LIST\"] = \"6.0;6.1;6.2;7.0;7.5;8.0;8.6\"\n",
" elif bare_metal_version == Version(\"11.0\"):\n",
" os.environ[\"TORCH_CUDA_ARCH_LIST\"] = \"6.0;6.1;6.2;7.0;7.5;8.0\"\n",
" else:\n",
" os.environ[\"TORCH_CUDA_ARCH_LIST\"] = \"6.0;6.1;6.2;7.0;7.5\"\n",
"\n",
"print(\"\\n\\ntorch.__version__ = {}\\n\\n\".format(torch.__version__))\n",
"TORCH_MAJOR = int(torch.__version__.split(\".\")[0])\n",
"TORCH_MINOR = int(torch.__version__.split(\".\")[1])\n",
"\n",
"if TORCH_MAJOR == 0 and TORCH_MINOR < 4:\n",
" raise RuntimeError(\n",
" \"Apex requires Pytorch 0.4 or newer.\\nThe latest stable release can be obtained from https://pytorch.org/\"\n",
" )\n",
"\n",
"cmdclass = {}\n",
"ext_modules = []\n",
"\n",
"extras = {}\n",
"\n",
"if \"--cpp_ext\" in sys.argv or \"--cuda_ext\" in sys.argv:\n",
" if TORCH_MAJOR == 0:\n",
" raise RuntimeError(\n",
" \"--cpp_ext requires Pytorch 1.0 or later, \" \"found torch.__version__ = {}\".format(torch.__version__)\n",
" )\n",
"\n",
"if \"--cpp_ext\" in sys.argv:\n",
" sys.argv.remove(\"--cpp_ext\")\n",
" ext_modules.append(CppExtension(\"apex_C\", [\"csrc/flatten_unflatten.cpp\"]))\n",
"\n",
"\n",
"# Set up macros for forward/backward compatibility hack around\n",
"# https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e\n",
"# and\n",
"# https://github.com/NVIDIA/apex/issues/456\n",
"# https://github.com/pytorch/pytorch/commit/eb7b39e02f7d75c26d8a795ea8c7fd911334da7e#diff-4632522f237f1e4e728cb824300403ac\n",
"version_ge_1_1 = []\n",
"if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):\n",
" version_ge_1_1 = [\"-DVERSION_GE_1_1\"]\n",
"version_ge_1_3 = []\n",
"if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):\n",
" version_ge_1_3 = [\"-DVERSION_GE_1_3\"]\n",
"version_ge_1_5 = []\n",
"if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):\n",
" version_ge_1_5 = [\"-DVERSION_GE_1_5\"]\n",
"version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5\n",
"\n",
"_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)\n",
"\n",
"if \"--distributed_adam\" in sys.argv:\n",
" sys.argv.remove(\"--distributed_adam\")\n",
" raise_if_cuda_home_none(\"--distributed_adam\")\n",
" ext_modules.append(\n",
" CUDAExtension(\n",
" name=\"distributed_adam_cuda\",\n",
" sources=[\n",
" \"apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp\",\n",
" \"apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu\",\n",
" ],\n",
" include_dirs=[os.path.join(this_dir, \"csrc\")],\n",
" extra_compile_args={\n",
" \"cxx\": [\"-O3\"] + version_dependent_macros,\n",
" \"nvcc\": [\"-O3\", \"--use_fast_math\"] + version_dependent_macros,\n",
" },\n",
" )\n",
" )\n",
"\n",
"if \"--distributed_lamb\" in sys.argv:\n",
" sys.argv.remove(\"--distributed_lamb\")\n",
" raise_if_cuda_home_none(\"--distributed_lamb\")\n",
" ext_modules.append(\n",
" CUDAExtension(\n",
" name=\"distributed_lamb_cuda\",\n",
" sources=[\n",
" \"apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp\",\n",
" \"apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu\",\n",
" ],\n",
" include_dirs=[os.path.join(this_dir, \"csrc\")],\n",
" extra_compile_args={\n",
" \"cxx\": [\"-O3\"] + version_dependent_macros,\n",
" \"nvcc\": [\"-O3\", \"--use_fast_math\"] + version_dependent_macros,\n",
" },\n",
" )\n",
" )\n",
"\n",
"if \"--cuda_ext\" in sys.argv:\n",
" sys.argv.remove(\"--cuda_ext\")\n",
" raise_if_cuda_home_none(\"--cuda_ext\")\n",
" check_cuda_torch_binary_vs_bare_metal(CUDA_HOME)\n",
"\n",
" ext_modules.append(\n",
" CUDAExtension(\n",
" name=\"amp_C\",\n",
" sources=[\n",
" \"csrc/amp_C_frontend.cpp\",\n",
" \"csrc/multi_tensor_sgd_kernel.cu\",\n",
" \"csrc/multi_tensor_scale_kernel.cu\",\n",
" \"csrc/multi_tensor_axpby_kernel.cu\",\n",
" \"csrc/multi_tensor_l2norm_kernel.cu\",\n",
" \"csrc/multi_tensor_l2norm_kernel_mp.cu\",\n",
" \"csrc/multi_tensor_l2norm_scale_kernel.cu\",\n",
" \"csrc/multi_tensor_lamb_stage_1.cu\",\n",
" \"csrc/multi_tensor_lamb_stage_2.cu\",\n",
" \"csrc/multi_tensor_adam.cu\",\n",
" \"csrc/multi_tensor_adagrad.cu\",\n",
" \"csrc/multi_tensor_novograd.cu\",\n",
" \"csrc/multi_tensor_lamb.cu\",\n",
" \"csrc/multi_tensor_lamb_mp.cu\",\n",
" \"csrc/update_scale_hysteresis.cu\",\n",
" ],\n",
" extra_compile_args={\n",
" \"cxx\": [\"-O3\"] + version_dependent_macros,\n",
" \"nvcc\": [\n",
" \"-lineinfo\",\n",
" \"-O3\",\n",
" # '--resource-usage',\n",
" \"--use_fast_math\",\n",
" ] + version_dependent_macros,\n",
" },\n",
" )\n",
" )\n",
" ext_modules.append(\n",
" CUDAExtension(\n",
" name=\"syncbn\",\n",
" sources=[\"csrc/syncbn.cpp\", \"csrc/welford.cu\"],\n",
" extra_compile_args={\n",
" \"cxx\": [\"-O3\"] + version_dependent_macros,\n",
" \"nvcc\": [\"-O3\"] + version_dependent_macros,\n",
" },\n",
" )\n",
" )\n",
"\n",
" ext_modules.append(\n",
" CUDAExtension(\n",
" name=\"fused_layer_norm_cuda\",\n",
" sources=[\"csrc/layer_norm_cuda.cpp\", \"csrc/layer_norm_cuda_kernel.cu\"],\n",
" extra_compile_args={\n",
" \"cxx\": [\"-O3\"] + version_dependent_macros,\n",
" \"nvcc\": [\"-maxrregcount=50\", \"-O3\", \"--use_fast_math\"] + version_dependent_macros,\n",
" },\n",
" )\n",
" )\n",
"\n",
" ext_modules.append(\n",
" CUDAExtension(\n",
" name=\"mlp_cuda\",\n",
" sources=[\"csrc/mlp.cpp\", \"csrc/mlp_cuda.cu\"],\n",
" extra_compile_args={\n",
" \"cxx\": [\"-O3\"] + version_dependent_macros,\n",
" \"nvcc\": [\"-O3\"] + version_dependent_macros,\n",
" },\n",
" )\n",
" )\n",
" ext_modules.append(\n",
" CUDAExtension(\n",
" name=\"fused_dense_cuda\",\n",
" sources=[\"csrc/fused_dense.cpp\", \"csrc/fused_dense_cuda.cu\"],\n",
" extra_compile_args={\n",
" \"cxx\": [\"-O3\"] + version_dependent_macros,\n",
" \"nvcc\": [\"-O3\"] + version_dependent_macros,\n",
" },\n",
" )\n",
" )\n",
"\n",
" ext_modules.append(\n",
" CUDAExtension(\n",
" name=\"scaled_upper_triang_masked_softmax_cuda\",\n",
" sources=[\n",
" \"csrc/megatron/scaled_upper_triang_masked_softmax.cpp\",\n",
" \"csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu\",\n",
" ],\n",
" include_dirs=[os.path.join(this_dir, \"csrc\")],\n",
" extra_compile_args={\n",
" \"cxx\": [\"-O3\"] + version_dependent_macros,\n",
" \"nvcc\": [\n",
" \"-O3\",\n",
" \"-U__CUDA_NO_HALF_OPERATORS__\",\n",
" \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n",
" \"--expt-relaxed-constexpr\",\n",
" \"--expt-extended-lambda\",\n",
" ] + version_dependent_macros,\n",
" },\n",
" )\n",
" )\n",
"\n",
" ext_modules.append(\n",
" CUDAExtension(\n",
" name=\"generic_scaled_masked_softmax_cuda\",\n",
" sources=[\n",
" \"csrc/megatron/generic_scaled_masked_softmax.cpp\",\n",
" \"csrc/megatron/generic_scaled_masked_softmax_cuda.cu\",\n",
" ],\n",
" include_dirs=[os.path.join(this_dir, \"csrc\")],\n",
" extra_compile_args={\n",
" \"cxx\": [\"-O3\"] + version_dependent_macros,\n",
" \"nvcc\": [\n",
" \"-O3\",\n",
" \"-U__CUDA_NO_HALF_OPERATORS__\",\n",
" \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n",
" \"--expt-relaxed-constexpr\",\n",
" \"--expt-extended-lambda\",\n",
" ] + version_dependent_macros,\n",
" },\n",
" )\n",
" )\n",
"\n",
" ext_modules.append(\n",
" CUDAExtension(\n",
" name=\"scaled_masked_softmax_cuda\",\n",
" sources=[\"csrc/megatron/scaled_masked_softmax.cpp\", \"csrc/megatron/scaled_masked_softmax_cuda.cu\"],\n",
" include_dirs=[os.path.join(this_dir, \"csrc\")],\n",
" extra_compile_args={\n",
" \"cxx\": [\"-O3\"] + version_dependent_macros,\n",
" \"nvcc\": [\n",
" \"-O3\",\n",
" \"-U__CUDA_NO_HALF_OPERATORS__\",\n",
" \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n",
" \"--expt-relaxed-constexpr\",\n",
" \"--expt-extended-lambda\",\n",
" ] + version_dependent_macros,\n",
" },\n",
" )\n",
" )\n",
"\n",
" ext_modules.append(\n",
" CUDAExtension(\n",
" name=\"scaled_softmax_cuda\",\n",
" sources=[\"csrc/megatron/scaled_softmax.cpp\", \"csrc/megatron/scaled_softmax_cuda.cu\"],\n",
" include_dirs=[os.path.join(this_dir, \"csrc\")],\n",
" extra_compile_args={\n",
" \"cxx\": [\"-O3\"] + version_dependent_macros,\n",
" \"nvcc\": [\n",
" \"-O3\",\n",
" \"-U__CUDA_NO_HALF_OPERATORS__\",\n",
" \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n",
" \"--expt-relaxed-constexpr\",\n",
" \"--expt-extended-lambda\",\n",
" ] + version_dependent_macros,\n",
" },\n",
" )\n",
" )\n",
"\n",
" ext_modules.append(\n",
" CUDAExtension(\n",
" name=\"fused_rotary_positional_embedding\",\n",
" sources=[\n",
" \"csrc/megatron/fused_rotary_positional_embedding.cpp\",\n",
" \"csrc/megatron/fused_rotary_positional_embedding_cuda.cu\",\n",
" ],\n",
" include_dirs=[os.path.join(this_dir, \"csrc\")],\n",
" extra_compile_args={\n",
" \"cxx\": [\"-O3\"] + version_dependent_macros,\n",
" \"nvcc\": [\n",
" \"-O3\",\n",
" \"-U__CUDA_NO_HALF_OPERATORS__\",\n",
" \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n",
" \"--expt-relaxed-constexpr\",\n",
" \"--expt-extended-lambda\",\n",
" ] + version_dependent_macros,\n",
" },\n",
" )\n",
" )\n",
"\n",
" if bare_metal_version >= Version(\"11.0\"):\n",
"\n",
" cc_flag = []\n",
" cc_flag.append(\"-gencode\")\n",
" cc_flag.append(\"arch=compute_70,code=sm_70\")\n",
" cc_flag.append(\"-gencode\")\n",
" cc_flag.append(\"arch=compute_80,code=sm_80\")\n",
" if bare_metal_version >= Version(\"11.1\"):\n",
" cc_flag.append(\"-gencode\")\n",
" cc_flag.append(\"arch=compute_86,code=sm_86\")\n",
" if bare_metal_version >= Version(\"11.8\"):\n",
" cc_flag.append(\"-gencode\")\n",
" cc_flag.append(\"arch=compute_90,code=sm_90\")\n",
"\n",
" ext_modules.append(\n",
" CUDAExtension(\n",
" name=\"fused_weight_gradient_mlp_cuda\",\n",
" include_dirs=[os.path.join(this_dir, \"csrc\")],\n",
" sources=[\n",
" \"csrc/megatron/fused_weight_gradient_dense.cpp\",\n",
" \"csrc/megatron/fused_weight_gradient_dense_cuda.cu\",\n",
" \"csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu\",\n",
" ],\n",
" extra_compile_args={\n",
" \"cxx\": [\"-O3\"] + version_dependent_macros,\n",
" \"nvcc\": [\n",
" \"-O3\",\n",
" \"-U__CUDA_NO_HALF_OPERATORS__\",\n",
" \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n",
" \"--expt-relaxed-constexpr\",\n",
" \"--expt-extended-lambda\",\n",
" \"--use_fast_math\",\n",
" ] + version_dependent_macros + cc_flag,\n",
" },\n",
" )\n",
" )\n",
"\n",
"if \"--permutation_search\" in sys.argv:\n",
" sys.argv.remove(\"--permutation_search\")\n",
"\n",
" if CUDA_HOME is None:\n",
" raise RuntimeError(\"--permutation_search was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.\")\n",
" else:\n",
" cc_flag = ['-Xcompiler', '-fPIC', '-shared']\n",
" ext_modules.append(\n",
" CUDAExtension(name='permutation_search_cuda',\n",
" sources=['apex/contrib/sparsity/permutation_search_kernels/CUDA_kernels/permutation_search_kernels.cu'],\n",
" include_dirs=[os.path.join(this_dir, 'apex', 'contrib', 'sparsity', 'permutation_search_kernels', 'CUDA_kernels')],\n",
" extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,\n",
" 'nvcc':['-O3'] + version_dependent_macros + cc_flag}))\n",
"\n",
"if \"--bnp\" in sys.argv:\n",
" sys.argv.remove(\"--bnp\")\n",
" raise_if_cuda_home_none(\"--bnp\")\n",
" ext_modules.append(\n",
" CUDAExtension(\n",
" name=\"bnp\",\n",
" sources=[\n",
" \"apex/contrib/csrc/groupbn/batch_norm.cu\",\n",
" \"apex/contrib/csrc/groupbn/ipc.cu\",\n",
" \"apex/contrib/csrc/groupbn/interface.cpp\",\n",
" \"apex/contrib/csrc/groupbn/batch_norm_add_relu.cu\",\n",
" ],\n",
" include_dirs=[os.path.join(this_dir, \"csrc\")],\n",
" extra_compile_args={\n",
" \"cxx\": [] + version_dependent_macros,\n",
" \"nvcc\": [\n",
" \"-DCUDA_HAS_FP16=1\",\n",
" \"-D__CUDA_NO_HALF_OPERATORS__\",\n",
" \"-D__CUDA_NO_HALF_CONVERSIONS__\",\n",
" \"-D__CUDA_NO_HALF2_OPERATORS__\",\n",
" ] + version_dependent_macros,\n",
" },\n",
" )\n",
" )\n",
"\n",
"if \"--xentropy\" in sys.argv:\n",
" from datetime import datetime\n",
" sys.argv.remove(\"--xentropy\")\n",
" raise_if_cuda_home_none(\"--xentropy\")\n",
" xentropy_ver = datetime.today().strftime(\"%y.%m.%d\")\n",
" print(f\"`--xentropy` setting version of {xentropy_ver}\")\n",
" ext_modules.append(\n",
" CUDAExtension(\n",
" name=\"xentropy_cuda\",\n",
" sources=[\"apex/contrib/csrc/xentropy/interface.cpp\", \"apex/contrib/csrc/xentropy/xentropy_kernel.cu\"],\n",
" include_dirs=[os.path.join(this_dir, \"csrc\")],\n",
" extra_compile_args={\n",
" \"cxx\": [\"-O3\"] + version_dependent_macros + [f'-DXENTROPY_VER=\"{xentropy_ver}\"'],\n",
" \"nvcc\": [\"-O3\"] + version_dependent_macros,\n",
" },\n",
" )\n",
" )\n",
"\n",
"if \"--focal_loss\" in sys.argv:\n",
" sys.argv.remove(\"--focal_loss\")\n",
" raise_if_cuda_home_none(\"--focal_loss\")\n",
" ext_modules.append(\n",
" CUDAExtension(\n",
" name='focal_loss_cuda',\n",
" sources=[\n",
" 'apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp',\n",
" 'apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu',\n",
" ],\n",
" include_dirs=[os.path.join(this_dir, 'csrc')],\n",
" extra_compile_args={\n",
" 'cxx': ['-O3'] + version_dependent_macros,\n",
" 'nvcc':['-O3', '--use_fast_math', '--ftz=false'] + version_dependent_macros,\n",
" },\n",
" )\n",
" )\n",
"\n",
"if \"--group_norm\" in sys.argv:\n",
" sys.argv.remove(\"--group_norm\")\n",
" raise_if_cuda_home_none(\"--group_norm\")\n",
"\n",
" # CUDA group norm supports from SM70\n",
" arch_flags = []\n",
" for arch in [70, 75, 80, 86, 90]:\n",
" arch_flag = f\"-gencode=arch=compute_{arch},code=sm_{arch}\"\n",
" arch_flags.append(arch_flag)\n",
" arch_flag = f\"-gencode=arch=compute_90,code=compute_90\"\n",
" arch_flags.append(arch_flag)\n",
"\n",
" ext_modules.append(\n",
" CUDAExtension(\n",
" name=\"group_norm_cuda\",\n",
" sources=[\n",
" \"apex/contrib/csrc/group_norm/group_norm_nhwc_op.cpp\",\n",
" ] + glob.glob(\"apex/contrib/csrc/group_norm/*.cu\"),\n",
" include_dirs=[os.path.join(this_dir, 'csrc')],\n",
" extra_compile_args={\n",
" \"cxx\": [\"-O3\", \"-std=c++17\"] + version_dependent_macros,\n",
" \"nvcc\": [\n",
" \"-O3\", \"-std=c++17\", \"--use_fast_math\", \"--ftz=false\",\n",
" ] + arch_flags + version_dependent_macros,\n",
" },\n",
" )\n",
" )\n",
"\n",
"if \"--index_mul_2d\" in sys.argv:\n",
" sys.argv.remove(\"--index_mul_2d\")\n",
" raise_if_cuda_home_none(\"--index_mul_2d\")\n",
" ext_modules.append(\n",
" CUDAExtension(\n",
" name='fused_index_mul_2d',\n",
" sources=[\n",
" 'apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp',\n",
" 'apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu',\n",
" ],\n",
" include_dirs=[os.path.join(this_dir, 'csrc')],\n",
" extra_compile_args={\n",
" 'cxx': ['-O3'] + version_dependent_macros,\n",
" 'nvcc':['-O3', '--use_fast_math', '--ftz=false'] + version_dependent_macros,\n",
" },\n",
" )\n",
" )\n",
"\n",
"if \"--deprecated_fused_adam\" in sys.argv:\n",
" sys.argv.remove(\"--deprecated_fused_adam\")\n",
" raise_if_cuda_home_none(\"--deprecated_fused_adam\")\n",
" ext_modules.append(\n",
" CUDAExtension(\n",
" name=\"fused_adam_cuda\",\n",
" sources=[\n",
" \"apex/contrib/csrc/optimizers/fused_adam_cuda.cpp\",\n",
" \"apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu\",\n",
" ],\n",
" include_dirs=[os.path.join(this_dir, \"csrc\")],\n",
" extra_compile_args={\n",
" \"cxx\": [\"-O3\"] + version_dependent_macros,\n",
" \"nvcc\": [\"-O3\", \"--use_fast_math\"] + version_dependent_macros,\n",
" },\n",
" )\n",
" )\n",
"\n",
"if \"--deprecated_fused_lamb\" in sys.argv:\n",
" sys.argv.remove(\"--deprecated_fused_lamb\")\n",
" raise_if_cuda_home_none(\"--deprecated_fused_lamb\")\n",
" ext_modules.append(\n",
" CUDAExtension(\n",
" name=\"fused_lamb_cuda\",\n",
" sources=[\n",
" \"apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp\",\n",
" \"apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu\",\n",
" \"csrc/multi_tensor_l2norm_kernel.cu\",\n",
" ],\n",
" include_dirs=[os.path.join(this_dir, \"csrc\")],\n",
" extra_compile_args={\n",
" \"cxx\": [\"-O3\"] + version_dependent_macros,\n",
" \"nvcc\": [\"-O3\", \"--use_fast_math\"] + version_dependent_macros,\n",
" },\n",
" )\n",
" )\n",
"\n",
"# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h\n",
"# See https://github.com/pytorch/pytorch/pull/70650\n",
"generator_flag = []\n",
"torch_dir = torch.__path__[0]\n",
"if os.path.exists(os.path.join(torch_dir, \"include\", \"ATen\", \"CUDAGeneratorImpl.h\")):\n",
" generator_flag = [\"-DOLD_GENERATOR_PATH\"]\n",
"\n",
"if \"--fast_layer_norm\" in sys.argv:\n",
" sys.argv.remove(\"--fast_layer_norm\")\n",
" raise_if_cuda_home_none(\"--fast_layer_norm\")\n",
"\n",
" cc_flag = []\n",
" cc_flag.append(\"-gencode\")\n",
" cc_flag.append(\"arch=compute_70,code=sm_70\")\n",
"\n",
" if bare_metal_version >= Version(\"11.0\"):\n",
" cc_flag.append(\"-gencode\")\n",
" cc_flag.append(\"arch=compute_80,code=sm_80\")\n",
" if bare_metal_version >= Version(\"11.8\"):\n",
" cc_flag.append(\"-gencode\")\n",
" cc_flag.append(\"arch=compute_90,code=sm_90\")\n",
"\n",
" ext_modules.append(\n",
" CUDAExtension(\n",
" name=\"fast_layer_norm\",\n",
" sources=[\n",
" \"apex/contrib/csrc/layer_norm/ln_api.cpp\",\n",
" \"apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu\",\n",
" \"apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu\",\n",
" ],\n",
" extra_compile_args={\n",
" \"cxx\": [\"-O3\"] + version_dependent_macros + generator_flag,\n",
" \"nvcc\": [\n",
" \"-O3\",\n",
" \"-U__CUDA_NO_HALF_OPERATORS__\",\n",
" \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n",
" \"-U__CUDA_NO_BFLOAT16_OPERATORS__\",\n",
" \"-U__CUDA_NO_BFLOAT16_CONVERSIONS__\",\n",
" \"-U__CUDA_NO_BFLOAT162_OPERATORS__\",\n",
" \"-U__CUDA_NO_BFLOAT162_CONVERSIONS__\",\n",
" \"-I./apex/contrib/csrc/layer_norm/\",\n",
" \"--expt-relaxed-constexpr\",\n",
" \"--expt-extended-lambda\",\n",
" \"--use_fast_math\",\n",
" ] + version_dependent_macros + generator_flag + cc_flag,\n",
" },\n",
" include_dirs=[os.path.join(this_dir, \"apex/contrib/csrc/layer_norm\")],\n",
" )\n",
" )\n",
"\n",
"if \"--fmha\" in sys.argv:\n",
" sys.argv.remove(\"--fmha\")\n",
" raise_if_cuda_home_none(\"--fmha\")\n",
"\n",
" if bare_metal_version < Version(\"11.0\"):\n",
" raise RuntimeError(\"--fmha only supported on sm_80 and sm_90 GPUs\")\n",
"\n",
" cc_flag = []\n",
" cc_flag.append(\"-gencode\")\n",
" cc_flag.append(\"arch=compute_80,code=sm_80\")\n",
" if bare_metal_version >= Version(\"11.8\"):\n",
" cc_flag.append(\"-gencode\")\n",
" cc_flag.append(\"arch=compute_90,code=sm_90\")\n",
"\n",
" ext_modules.append(\n",
" CUDAExtension(\n",
" name=\"fmhalib\",\n",
" sources=[\n",
" \"apex/contrib/csrc/fmha/fmha_api.cpp\",\n",
" \"apex/contrib/csrc/fmha/src/fmha_fill.cu\",\n",
" \"apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu\",\n",
" \"apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu\",\n",
" \"apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu\",\n",
" \"apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu\",\n",
" \"apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu\",\n",
" \"apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu\",\n",
" \"apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu\",\n",
" \"apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu\",\n",
" \"apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu\",\n",
" ],\n",
" extra_compile_args={\n",
" \"cxx\": [\"-O3\"] + version_dependent_macros + generator_flag,\n",
" \"nvcc\": [\n",
" \"-O3\",\n",
" \"-U__CUDA_NO_HALF_OPERATORS__\",\n",
" \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n",
" \"--expt-relaxed-constexpr\",\n",
" \"--expt-extended-lambda\",\n",
" \"--use_fast_math\",\n",
" ] + version_dependent_macros + generator_flag + cc_flag,\n",
" },\n",
" include_dirs=[\n",
" os.path.join(this_dir, \"apex/contrib/csrc\"),\n",
" os.path.join(this_dir, \"apex/contrib/csrc/fmha/src\"),\n",
" ],\n",
" )\n",
" )\n",
"\n",
"\n",
"if \"--fast_multihead_attn\" in sys.argv:\n",
" sys.argv.remove(\"--fast_multihead_attn\")\n",
" raise_if_cuda_home_none(\"--fast_multihead_attn\")\n",
"\n",
" cc_flag = []\n",
" cc_flag.append(\"-gencode\")\n",
" cc_flag.append(\"arch=compute_70,code=sm_70\")\n",
"\n",
" if bare_metal_version >= Version(\"11.0\"):\n",
" cc_flag.append(\"-gencode\")\n",
" cc_flag.append(\"arch=compute_80,code=sm_80\")\n",
" if bare_metal_version >= Version(\"11.1\"):\n",
" cc_flag.append(\"-gencode\")\n",
" cc_flag.append(\"arch=compute_86,code=sm_86\")\n",
" if bare_metal_version >= Version(\"11.8\"):\n",
" cc_flag.append(\"-gencode\")\n",
" cc_flag.append(\"arch=compute_90,code=sm_90\")\n",
"\n",
" subprocess.run([\"git\", \"submodule\", \"update\", \"--init\", \"apex/contrib/csrc/multihead_attn/cutlass\"])\n",
" ext_modules.append(\n",
" CUDAExtension(\n",
" name=\"fast_multihead_attn\",\n",
" sources=[\n",
" \"apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp\",\n",
" \"apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu\",\n",
" \"apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu\",\n",
" \"apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu\",\n",
" \"apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu\",\n",
" \"apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu\",\n",
" \"apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu\",\n",
" \"apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu\",\n",
" \"apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu\",\n",
" ],\n",
" extra_compile_args={\n",
" \"cxx\": [\"-O3\"] + version_dependent_macros + generator_flag,\n",
" \"nvcc\": [\n",
" \"-O3\",\n",
" \"-U__CUDA_NO_HALF_OPERATORS__\",\n",
" \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n",
" \"--expt-relaxed-constexpr\",\n",
" \"--expt-extended-lambda\",\n",
" \"--use_fast_math\",\n",
" ]\n",
" + version_dependent_macros\n",
" + generator_flag\n",
" + cc_flag,\n",
" },\n",
" include_dirs=[\n",
" os.path.join(this_dir, \"apex/contrib/csrc/multihead_attn/cutlass/include/\"),\n",
" os.path.join(this_dir, \"apex/contrib/csrc/multihead_attn/cutlass/tools/util/include\")\n",
" ],\n",
" )\n",
" )\n",
"\n",
"if \"--transducer\" in sys.argv:\n",
" sys.argv.remove(\"--transducer\")\n",
" raise_if_cuda_home_none(\"--transducer\")\n",
" ext_modules.append(\n",
" CUDAExtension(\n",
" name=\"transducer_joint_cuda\",\n",
" sources=[\n",
" \"apex/contrib/csrc/transducer/transducer_joint.cpp\",\n",
" \"apex/contrib/csrc/transducer/transducer_joint_kernel.cu\",\n",
" ],\n",
" extra_compile_args={\n",
" \"cxx\": [\"-O3\"] + version_dependent_macros + generator_flag,\n",
" \"nvcc\": [\"-O3\"] + version_dependent_macros + generator_flag,\n",
" },\n",
" include_dirs=[os.path.join(this_dir, \"csrc\"), os.path.join(this_dir, \"apex/contrib/csrc/multihead_attn\")],\n",
" )\n",
" )\n",
" ext_modules.append(\n",
" CUDAExtension(\n",
" name=\"transducer_loss_cuda\",\n",
" sources=[\n",
" \"apex/contrib/csrc/transducer/transducer_loss.cpp\",\n",
" \"apex/contrib/csrc/transducer/transducer_loss_kernel.cu\",\n",
" ],\n",
" include_dirs=[os.path.join(this_dir, \"csrc\")],\n",
" extra_compile_args={\n",
" \"cxx\": [\"-O3\"] + version_dependent_macros,\n",
" \"nvcc\": [\"-O3\"] + version_dependent_macros,\n",
" },\n",
" )\n",
" )\n",
"\n",
"if \"--cudnn_gbn\" in sys.argv:\n",
" sys.argv.remove(\"--cudnn_gbn\")\n",
" raise_if_cuda_home_none(\"--cudnn_gbn\")\n",
" if check_cudnn_version_and_warn(\"--cudnn_gbn\", 8500):\n",
" subprocess.run([\"git\", \"submodule\", \"update\", \"--init\", \"apex/contrib/csrc/cudnn-frontend/\"])\n",
" ext_modules.append(\n",
" CUDAExtension(\n",
" name=\"cudnn_gbn_lib\",\n",
" sources=[\n",
" \"apex/contrib/csrc/cudnn_gbn/norm_sample.cpp\",\n",
" \"apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp\",\n",
" ],\n",
" include_dirs=[os.path.join(this_dir, \"apex/contrib/csrc/cudnn-frontend/include\")],\n",
" extra_compile_args={\"cxx\": [\"-O3\", \"-g\"] + version_dependent_macros + generator_flag},\n",
" )\n",
" )\n",
"\n",
"if \"--peer_memory\" in sys.argv:\n",
" sys.argv.remove(\"--peer_memory\")\n",
" raise_if_cuda_home_none(\"--peer_memory\")\n",
" ext_modules.append(\n",
" CUDAExtension(\n",
" name=\"peer_memory_cuda\",\n",
" sources=[\n",
" \"apex/contrib/csrc/peer_memory/peer_memory_cuda.cu\",\n",
" \"apex/contrib/csrc/peer_memory/peer_memory.cpp\",\n",
" ],\n",
" extra_compile_args={\"cxx\": [\"-O3\"] + version_dependent_macros + generator_flag},\n",
" )\n",
" )\n",
"\n",
"# NOTE: Requires NCCL >= 2.10.3\n",
"if \"--nccl_p2p\" in sys.argv:\n",
" sys.argv.remove(\"--nccl_p2p\")\n",
" raise_if_cuda_home_none(\"--nccl_p2p\")\n",
" # Check NCCL version.\n",
" _nccl_version_getter = load(\n",
" name=\"_nccl_version_getter\",\n",
" sources=[\"apex/contrib/csrc/nccl_p2p/nccl_version.cpp\", \"apex/contrib/csrc/nccl_p2p/nccl_version_check.cu\"],\n",
"\n",
" )\n",
" _available_nccl_version = _nccl_version_getter.get_nccl_version()\n",
" if _available_nccl_version >= (2, 10):\n",
" ext_modules.append(\n",
" CUDAExtension(\n",
" name=\"nccl_p2p_cuda\",\n",
" sources=[\n",
" \"apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu\",\n",
" \"apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp\",\n",
" ],\n",
" extra_compile_args={\"cxx\": [\"-O3\"] + version_dependent_macros + generator_flag},\n",
" )\n",
" )\n",
" else:\n",
" warnings.warn(\n",
" f\"Skip `--nccl_p2p` as it requires NCCL 2.10.3 or later, but {_available_nccl_version[0]}.{_available_nccl_version[1]}\"\n",
" )\n",
"\n",
"# note (mkozuki): Now `--fast_bottleneck` option (i.e. apex/contrib/bottleneck) depends on `--peer_memory` and `--nccl_p2p`.\n",
"if \"--fast_bottleneck\" in sys.argv:\n",
" sys.argv.remove(\"--fast_bottleneck\")\n",
" raise_if_cuda_home_none(\"--fast_bottleneck\")\n",
" if check_cudnn_version_and_warn(\"--fast_bottleneck\", 8400):\n",
" subprocess.run([\"git\", \"submodule\", \"update\", \"--init\", \"apex/contrib/csrc/cudnn-frontend/\"])\n",
" ext_modules.append(\n",
" CUDAExtension(\n",
" name=\"fast_bottleneck\",\n",
" sources=[\"apex/contrib/csrc/bottleneck/bottleneck.cpp\"],\n",
" include_dirs=[os.path.join(this_dir, \"apex/contrib/csrc/cudnn-frontend/include\")],\n",
" extra_compile_args={\"cxx\": [\"-O3\"] + version_dependent_macros + generator_flag},\n",
" )\n",
" )\n",
"\n",
"\n",
"if \"--fused_conv_bias_relu\" in sys.argv:\n",
" sys.argv.remove(\"--fused_conv_bias_relu\")\n",
" raise_if_cuda_home_none(\"--fused_conv_bias_relu\")\n",
" if check_cudnn_version_and_warn(\"--fused_conv_bias_relu\", 8400):\n",
" subprocess.run([\"git\", \"submodule\", \"update\", \"--init\", \"apex/contrib/csrc/cudnn-frontend/\"])\n",
" ext_modules.append(\n",
" CUDAExtension(\n",
" name=\"fused_conv_bias_relu\",\n",
" sources=[\"apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp\"],\n",
" include_dirs=[os.path.join(this_dir, \"apex/contrib/csrc/cudnn-frontend/include\")],\n",
" extra_compile_args={\"cxx\": [\"-O3\"] + version_dependent_macros + generator_flag},\n",
" )\n",
" )\n",
"\n",
"\n",
"if \"--gpu_direct_storage\" in sys.argv:\n",
" sys.argv.remove(\"--gpu_direct_storage\")\n",
" raise_if_cuda_home_none(\"--gpu_direct_storage\")\n",
" ext_modules.append(\n",
" CUDAExtension(\n",
" name=\"_apex_gpu_direct_storage\",\n",
" sources=[\"apex/contrib/csrc/gpu_direct_storage/gds.cpp\", \"apex/contrib/csrc/gpu_direct_storage/gds_pybind.cpp\"],\n",
" include_dirs=[os.path.join(this_dir, \"apex/contrib/csrc/gpu_direct_storage\")],\n",
" libraries=[\"cufile\"],\n",
" extra_compile_args={\"cxx\": [\"-O3\"] + version_dependent_macros + generator_flag},\n",
" )\n",
" )\n",
"\n",
"\n",
"setup(\n",
" name=\"apex\",\n",
" version=\"0.1\",\n",
" packages=find_packages(\n",
" exclude=(\"build\", \"csrc\", \"include\", \"tests\", \"dist\", \"docs\", \"tests\", \"examples\", \"apex.egg-info\",)\n",
" ),\n",
" install_requires=[\"packaging>20.6\"],\n",
" description=\"PyTorch Extensions written by NVIDIA\",\n",
" ext_modules=ext_modules,\n",
" cmdclass={\"build_ext\": BuildExtension} if ext_modules else {},\n",
" extras_require=extras,\n",
")\n",
"\n",
"\"\"\""
],
"metadata": {
"id": "NoSS6GF7jIGx"
},
"execution_count": 41,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### continue"
],
"metadata": {
"id": "I4KGHXQajWVN"
}
},
{
"cell_type": "code",
"source": [
"# Open a file in write mode\n",
"file = open(\"/content/apex/setup.py\", \"w\")\n",
"\n",
"# Write content to the file\n",
"file.write(setup_py)\n",
"\n",
"# Close the file\n",
"file.close()"
],
"metadata": {
"id": "8PSAWo4pi3Hm"
},
"execution_count": 42,
"outputs": []
},
{
"cell_type": "code",
"source": [
"## warning takes more than 28 minutes to run this. go take a break.\n",
"%cd /content/apex\n",
"!pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings \"--build-option=--cpp_ext\" --config-settings \"--build-option=--cuda_ext\" ./"
],
"metadata": {
"id": "3DJfkdf8Yy7T"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## install open-sora"
],
"metadata": {
"id": "IArkaTtzjqu9"
}
},
{
"cell_type": "code",
"source": [
"%%shell\n",
"cd Open-Sora\n",
"pip install -v ."
],
"metadata": {
"id": "Jv5Krx34OVRB"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## download and save opensora and t5 large weights"
],
"metadata": {
"id": "BcJPY6CakQKE"
}
},
{
"cell_type": "markdown",
"source": [
"official implementation uses [t5-v1_1-xxl](https://huggingface.co/DeepFloyd/t5-v1_1-xxl/blob/main/config.json) model but i'm using a t5-v1_1-large model, but it wont work. cause output dimensions of xxl(4096) and large(1024) model are different. and i think the opensora is trained with xxl model so it only takes 4096 dim input text embedding, so i don't think we can swap the models.\n",
"\n",
"you can check `d_model` parameter for large and xxl model in the config.json files below.\n",
"- https://huggingface.co/google/t5-v1_1-large/blob/main/config.json\n",
"- https://huggingface.co/DeepFloyd/t5-v1_1-xxl/blob/main/config.json"
],
"metadata": {
"id": "jjfFCwyOkZas"
}
},
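{
"cell_type": "code",
"source": [
"# a quick way to confirm the mismatch described above: read d_model from each\n",
"# model's config.json (assumes the `transformers` package, which colab ships).\n",
"from transformers import AutoConfig\n",
"\n",
"large = AutoConfig.from_pretrained(\"google/t5-v1_1-large\")\n",
"xxl = AutoConfig.from_pretrained(\"DeepFloyd/t5-v1_1-xxl\")\n",
"print(\"t5-v1_1-large d_model:\", large.d_model)  # 1024\n",
"print(\"t5-v1_1-xxl d_model:\", xxl.d_model)  # 4096"
],
"metadata": {},
"execution_count": null,
"outputs": []
},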
{
"cell_type": "code",
"source": [
"!wget https://huggingface.co/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-16x256x256.pth?download=true -O OpenSora-v1-16x256x256.pth"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "55Hwb7AIM06c",
"outputId": "3c7374c6-51ad-4e7c-f720-36eafdbbfc99"
},
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"--2024-03-18 01:43:40-- https://huggingface.co/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-16x256x256.pth?download=true\n",
"Resolving huggingface.co (huggingface.co)... 18.164.174.17, 18.164.174.55, 18.164.174.118, ...\n",
"Connecting to huggingface.co (huggingface.co)|18.164.174.17|:443... connected.\n",
"HTTP request sent, awaiting response... 302 Found\n",
"Location: https://cdn-lfs-us-1.huggingface.co/repos/cf/4d/cf4d47208bb82c4834e214fbb6e9d8a68ccf8b671974432452891f849d57b965/50dfba596b8bd0da9a77d36142a32d5f4e0db826fb97668bd2be034d48a6fd39?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27OpenSora-v1-16x256x256.pth%3B+filename%3D%22OpenSora-v1-16x256x256.pth%22%3B&Expires=1710984863&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxMDk4NDg2M319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2NmLzRkL2NmNGQ0NzIwOGJiODJjNDgzNGUyMTRmYmI2ZTlkOGE2OGNjZjhiNjcxOTc0NDMyNDUyODkxZjg0OWQ1N2I5NjUvNTBkZmJhNTk2YjhiZDBkYTlhNzdkMzYxNDJhMzJkNWY0ZTBkYjgyNmZiOTc2NjhiZDJiZTAzNGQ0OGE2ZmQzOT9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPSoifV19&Signature=LolgD%7Ey-cdMWXA5xbYSCFxWoNaky2bxIzGzbodTeiacnudbsEeBnVCOhNcKJAsxveAyyD2zr5LKOx9sWhIGnWWLRxRkhMl35rzajP3dCoy9xrzwut6d4CKMlcb49NYtMQNwCJt5kOPLoKfFf1Z2jxjqM3wjdKWp8bzzsfCSjIRk3JLoqce3KeUCl1AVPfFC-O82Tp98XqlpyLLX6iLFNfKs%7Ec6wJ8VTw-uZ46I4TWmyLmJWOfMamqePHtXyAj-lQWY58kbfdsMNIoDEFe1FpTlmfIf01B6sN5gyCPFaQ%7EEakZm9nQnwaIKTjvbbQqkQMl1zQZKSsXPpMY7BnzX9nVw__&Key-Pair-Id=KCD77M1F0VK2B [following]\n",
"--2024-03-18 01:43:40-- https://cdn-lfs-us-1.huggingface.co/repos/cf/4d/cf4d47208bb82c4834e214fbb6e9d8a68ccf8b671974432452891f849d57b965/50dfba596b8bd0da9a77d36142a32d5f4e0db826fb97668bd2be034d48a6fd39?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27OpenSora-v1-16x256x256.pth%3B+filename%3D%22OpenSora-v1-16x256x256.pth%22%3B&Expires=1710984863&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxMDk4NDg2M319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2NmLzRkL2NmNGQ0NzIwOGJiODJjNDgzNGUyMTRmYmI2ZTlkOGE2OGNjZjhiNjcxOTc0NDMyNDUyODkxZjg0OWQ1N2I5NjUvNTBkZmJhNTk2YjhiZDBkYTlhNzdkMzYxNDJhMzJkNWY0ZTBkYjgyNmZiOTc2NjhiZDJiZTAzNGQ0OGE2ZmQzOT9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPSoifV19&Signature=LolgD%7Ey-cdMWXA5xbYSCFxWoNaky2bxIzGzbodTeiacnudbsEeBnVCOhNcKJAsxveAyyD2zr5LKOx9sWhIGnWWLRxRkhMl35rzajP3dCoy9xrzwut6d4CKMlcb49NYtMQNwCJt5kOPLoKfFf1Z2jxjqM3wjdKWp8bzzsfCSjIRk3JLoqce3KeUCl1AVPfFC-O82Tp98XqlpyLLX6iLFNfKs%7Ec6wJ8VTw-uZ46I4TWmyLmJWOfMamqePHtXyAj-lQWY58kbfdsMNIoDEFe1FpTlmfIf01B6sN5gyCPFaQ%7EEakZm9nQnwaIKTjvbbQqkQMl1zQZKSsXPpMY7BnzX9nVw__&Key-Pair-Id=KCD77M1F0VK2B\n",
"Resolving cdn-lfs-us-1.huggingface.co (cdn-lfs-us-1.huggingface.co)... 18.154.206.76, 18.154.206.94, 18.154.206.42, ...\n",
"Connecting to cdn-lfs-us-1.huggingface.co (cdn-lfs-us-1.huggingface.co)|18.154.206.76|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 3041879870 (2.8G) [application/zip]\n",
"Saving to: ‘OpenSora-v1-16x256x256.pth’\n",
"\n",
"OpenSora-v1-16x256x 100%[===================>] 2.83G 97.0MB/s in 26s \n",
"\n",
"2024-03-18 01:44:06 (114 MB/s) - ‘OpenSora-v1-16x256x256.pth’ saved [3041879870/3041879870]\n",
"\n"
]
}
]
},
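{
"cell_type": "code",
"source": [
"# optional: confirm the checkpoint downloaded fully; the wget log above\n",
"# reports 3041879870 bytes.\n",
"import os\n",
"print(os.path.getsize(\"/content/OpenSora-v1-16x256x256.pth\"))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},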
{
"cell_type": "code",
"source": [
"!mkdir /content/Open-Sora/pretrained_models\n",
"!mkdir /content/Open-Sora/pretrained_models/t5_ckpts\n",
"!mkdir /content/Open-Sora/pretrained_models/t5_ckpts/t5-v1_1-large"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "FqyMTQr3ReMn",
"outputId": "d471bbb8-14d8-4d63-b6b4-42d8adefa7b5"
},
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"mkdir: cannot create directory ‘/content/Open-Sora/pretrained_models’: File exists\n",
"mkdir: cannot create directory ‘/content/Open-Sora/pretrained_models/t5_ckpts’: File exists\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"## download t5 model. we download smaller model than the recommended for smaller memory usage\n",
"%cd /content/Open-Sora/pretrained_models/t5_ckpts/t5-v1_1-large\n",
"!wget https://huggingface.co/google/t5-v1_1-large/resolve/main/pytorch_model.bin?download=true -O pytorch_model.bin\n",
"%cd /content/Open-Sora"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "tSmmxSH4QtTY",
"outputId": "c079aaa8-cab5-4b6b-82b1-48f938054998"
},
"execution_count": 12,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/content/Open-Sora/pretrained_models/t5_ckpts/t5-v1_1-large\n",
"--2024-03-18 01:57:38-- https://huggingface.co/google/t5-v1_1-large/resolve/main/pytorch_model.bin?download=true\n",
"Resolving huggingface.co (huggingface.co)... 18.164.174.118, 18.164.174.17, 18.164.174.55, ...\n",
"Connecting to huggingface.co (huggingface.co)|18.164.174.118|:443... connected.\n",
"HTTP request sent, awaiting response... 302 Found\n",
"Location: https://cdn-lfs.huggingface.co/google/t5-v1_1-large/329243624cf70001991b9f0410d222a618bd33188eadc9890259b60cbc78f944?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27pytorch_model.bin%3B+filename%3D%22pytorch_model.bin%22%3B&response-content-type=application%2Foctet-stream&Expires=1710986258&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxMDk4NjI1OH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9nb29nbGUvdDUtdjFfMS1sYXJnZS8zMjkyNDM2MjRjZjcwMDAxOTkxYjlmMDQxMGQyMjJhNjE4YmQzMzE4OGVhZGM5ODkwMjU5YjYwY2JjNzhmOTQ0P3Jlc3BvbnNlLWNvbnRlbnQtZGlzcG9zaXRpb249KiZyZXNwb25zZS1jb250ZW50LXR5cGU9KiJ9XX0_&Signature=ACkNs%7ECRBhX0pyL3sb9c99773A4Xgf00buwRSqGIzTy8L696c2bjT-7Fj4GMk7M89IyL0jDPtHVnxahIceSLl0uRqAjgy6GtpWCQlMdbzrToiuH10-fb35%7EESIeds82zz1sWQGct26r3XUcsNX6LxfoQ7Qoavgx2VMbcGapytqDDPrSpohUoRo%7EClY5Z16U1EaaC1cEZgpW3cxBJ%7EFJdXT%7EQJk0Mfjqnb2taFyryzshJ7GZ1u95lQHkkQ%7EHKK5jOlEldF9%7EHsUhsJGrlKUsjwi-eYxc5nBrDRnhhBHkiMW5GNwYps-A6qrGwTrRI9WVfj%7Ebb1JwAR0X%7Eo1ivZwtakQ__&Key-Pair-Id=KVTP0A1DKRTAX [following]\n",
"--2024-03-18 01:57:38-- https://cdn-lfs.huggingface.co/google/t5-v1_1-large/329243624cf70001991b9f0410d222a618bd33188eadc9890259b60cbc78f944?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27pytorch_model.bin%3B+filename%3D%22pytorch_model.bin%22%3B&response-content-type=application%2Foctet-stream&Expires=1710986258&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxMDk4NjI1OH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9nb29nbGUvdDUtdjFfMS1sYXJnZS8zMjkyNDM2MjRjZjcwMDAxOTkxYjlmMDQxMGQyMjJhNjE4YmQzMzE4OGVhZGM5ODkwMjU5YjYwY2JjNzhmOTQ0P3Jlc3BvbnNlLWNvbnRlbnQtZGlzcG9zaXRpb249KiZyZXNwb25zZS1jb250ZW50LXR5cGU9KiJ9XX0_&Signature=ACkNs%7ECRBhX0pyL3sb9c99773A4Xgf00buwRSqGIzTy8L696c2bjT-7Fj4GMk7M89IyL0jDPtHVnxahIceSLl0uRqAjgy6GtpWCQlMdbzrToiuH10-fb35%7EESIeds82zz1sWQGct26r3XUcsNX6LxfoQ7Qoavgx2VMbcGapytqDDPrSpohUoRo%7EClY5Z16U1EaaC1cEZgpW3cxBJ%7EFJdXT%7EQJk0Mfjqnb2taFyryzshJ7GZ1u95lQHkkQ%7EHKK5jOlEldF9%7EHsUhsJGrlKUsjwi-eYxc5nBrDRnhhBHkiMW5GNwYps-A6qrGwTrRI9WVfj%7Ebb1JwAR0X%7Eo1ivZwtakQ__&Key-Pair-Id=KVTP0A1DKRTAX\n",
"Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 18.154.206.14, 18.154.206.17, 18.154.206.4, ...\n",
"Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|18.154.206.14|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 3132858253 (2.9G) [application/octet-stream]\n",
"Saving to: ‘pytorch_model.bin’\n",
"\n",
"pytorch_model.bin 100%[===================>] 2.92G 50.4MB/s in 59s \n",
"\n",
"2024-03-18 01:58:37 (50.9 MB/s) - ‘pytorch_model.bin’ saved [3132858253/3132858253]\n",
"\n",
"/content/Open-Sora\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"%cd /content/Open-Sora/pretrained_models/t5_ckpts/t5-v1_1-large"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "lGDf9_FgS0Tv",
"outputId": "ae9448bc-29d4-4522-d094-6688273234e7"
},
"execution_count": 13,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/content/Open-Sora/pretrained_models/t5_ckpts/t5-v1_1-large\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"%%shell\n",
"# tokenizer_config.json\n",
"wget https://huggingface.co/google/t5-v1_1-large/resolve/main/tokenizer_config.json -O tokenizer_config.json\n",
"\n",
"# special_tokens_map.json\n",
"wget https://huggingface.co/google/t5-v1_1-large/resolve/main/special_tokens_map.json -O special_tokens_map.json\n",
"\n",
"# generation_config.json\n",
"wget https://huggingface.co/google/t5-v1_1-large/resolve/main/generation_config.json -O generation_config.json\n",
"\n",
"# config.json\n",
"wget https://huggingface.co/google/t5-v1_1-large/resolve/main/config.json -O config.json\n",
"wget https://huggingface.co/google/t5-v1_1-large/resolve/main/spiece.model -O spiece.model\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "pKSDgzvmS4PN",
"outputId": "5bc43f76-cc1d-4143-f1e8-221fb938f0d8"
},
"execution_count": 15,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"--2024-03-18 02:02:32-- https://huggingface.co/google/t5-v1_1-large/resolve/main/tokenizer_config.json\n",
"Resolving huggingface.co (huggingface.co)... 18.164.174.118, 18.164.174.55, 18.164.174.17, ...\n",
"Connecting to huggingface.co (huggingface.co)|18.164.174.118|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 1857 (1.8K) [text/plain]\n",
"Saving to: ‘tokenizer_config.json’\n",
"\n",
"\rtokenizer_config.js 0%[ ] 0 --.-KB/s \rtokenizer_config.js 100%[===================>] 1.81K --.-KB/s in 0s \n",
"\n",
"2024-03-18 02:02:32 (1.20 GB/s) - ‘tokenizer_config.json’ saved [1857/1857]\n",
"\n",
"--2024-03-18 02:02:32-- https://huggingface.co/google/t5-v1_1-large/resolve/main/special_tokens_map.json\n",
"Resolving huggingface.co (huggingface.co)... 18.164.174.118, 18.164.174.55, 18.164.174.17, ...\n",
"Connecting to huggingface.co (huggingface.co)|18.164.174.118|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 1786 (1.7K) [text/plain]\n",
"Saving to: ‘special_tokens_map.json’\n",
"\n",
"special_tokens_map. 100%[===================>] 1.74K --.-KB/s in 0s \n",
"\n",
"2024-03-18 02:02:33 (1.48 GB/s) - ‘special_tokens_map.json’ saved [1786/1786]\n",
"\n",
"--2024-03-18 02:02:33-- https://huggingface.co/google/t5-v1_1-large/resolve/main/generation_config.json\n",
"Resolving huggingface.co (huggingface.co)... 18.164.174.118, 18.164.174.55, 18.164.174.17, ...\n",
"Connecting to huggingface.co (huggingface.co)|18.164.174.118|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 147 [text/plain]\n",
"Saving to: ‘generation_config.json’\n",
"\n",
"generation_config.j 100%[===================>] 147 --.-KB/s in 0s \n",
"\n",
"2024-03-18 02:02:33 (111 MB/s) - ‘generation_config.json’ saved [147/147]\n",
"\n",
"--2024-03-18 02:02:33-- https://huggingface.co/google/t5-v1_1-large/resolve/main/config.json\n",
"Resolving huggingface.co (huggingface.co)... 18.164.174.118, 18.164.174.55, 18.164.174.17, ...\n",
"Connecting to huggingface.co (huggingface.co)|18.164.174.118|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 607 [text/plain]\n",
"Saving to: ‘config.json’\n",
"\n",
"config.json 100%[===================>] 607 --.-KB/s in 0s \n",
"\n",
"2024-03-18 02:02:33 (499 MB/s) - ‘config.json’ saved [607/607]\n",
"\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": []
},
"metadata": {},
"execution_count": 15
}
]
},
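{
"cell_type": "code",
"source": [
"# quick check (a sketch, not part of the repo) that the directory now has the\n",
"# files the t5 text encoder loads from pretrained_models/t5_ckpts.\n",
"import os\n",
"ckpt_dir = \"/content/Open-Sora/pretrained_models/t5_ckpts/t5-v1_1-large\"\n",
"print(sorted(os.listdir(ckpt_dir)))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},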
{
"cell_type": "markdown",
"source": [
"## Run inference"
],
"metadata": {
"id": "DbZvWn-2kVKr"
}
},
{
"cell_type": "code",
"source": [
"%cd /content/Open-Sora"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "656w8J7DS3wN",
"outputId": "ee92ecef-8684-4665-9b0b-b2af45f8f314"
},
"execution_count": 16,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/content/Open-Sora\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"%cd /content/Open-Sora\n",
"!torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/16x256x256.py --ckpt-path /content/OpenSora-v1-16x256x256.pth\n",
"## this doesn't work cause I used t5-large which has 1024 dims and the open-sora model takes 4096 dim input from xxl model used by the main creators.\n",
"## we can't use the xxl models cause its too big for colab t4 gpus with only 14gb vram."
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "PpNNvKPPOpdr",
"outputId": "ddaca094-91d1-4ab4-abc3-70f5d3b2637c"
},
"execution_count": 48,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/content/Open-Sora\n",
"/usr/local/lib/python3.10/dist-packages/colossalai/pipeline/schedule/_utils.py:19: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
" _register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)\n",
"/usr/local/lib/python3.10/dist-packages/torch/utils/_pytree.py:254: UserWarning: <class 'collections.OrderedDict'> is already registered as pytree node. Overwriting the previous registration.\n",
" warnings.warn(\n",
"2024-03-18 03:18:01.946148: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"2024-03-18 03:18:01.946206: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"2024-03-18 03:18:01.948655: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
"2024-03-18 03:18:03.212411: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
"Config (path: configs/opensora/inference/16x256x256.py): {'num_frames': 16, 'fps': 8, 'image_size': (256, 256), 'model': {'type': 'STDiT-XL/2', 'space_scale': 0.5, 'time_scale': 1.0, 'enable_flashattn': True, 'enable_layernorm_kernel': True, 'from_pretrained': '/content/OpenSora-v1-16x256x256.pth'}, 'vae': {'type': 'VideoAutoencoderKL', 'from_pretrained': 'stabilityai/sd-vae-ft-ema'}, 'text_encoder': {'type': 't5', 'from_pretrained': './pretrained_models/t5_ckpts', 'model_max_length': 120}, 'scheduler': {'type': 'iddpm', 'num_sampling_steps': 100, 'cfg_scale': 7.0}, 'dtype': 'fp16', 'batch_size': 2, 'seed': 42, 'prompt_path': './assets/texts/t2v_samples.txt', 'save_dir': './outputs/samples/', 'multi_resolution': False}\n",
"/usr/local/lib/python3.10/dist-packages/colossalai/initialize.py:48: UserWarning: `config` is deprecated and will be removed soon.\n",
" warnings.warn(\"`config` is deprecated and will be removed soon.\")\n",
"\u001b[2;36m[03/18/24 03:18:04]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m colossalai - colossalai - INFO: \n",
"\u001b[2;36m \u001b[0m \u001b[35m/usr/local/lib/python3.10/dist-packages/colossalai/\u001b[0m\u001b[95minitialize.py\u001b[0m:\u001b[1;36m67\u001b[0m \n",
"\u001b[2;36m \u001b[0m launch \n",
"\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m colossalai - colossalai - INFO: Distributed environment is initialized,\n",
"\u001b[2;36m \u001b[0m world size: \u001b[1;36m1\u001b[0m \n",
"./pretrained_models/t5_ckpts/t5-v1_1-large\n",
"Traceback (most recent call last):\n",
" File \"/content/Open-Sora/scripts/inference.py\", line 112, in <module>\n",
" main()\n",
" File \"/content/Open-Sora/scripts/inference.py\", line 58, in main\n",
" model = build_module(\n",
" File \"/usr/local/lib/python3.10/dist-packages/opensora/registry.py\", line 22, in build_module\n",
" return builder.build(cfg)\n",
" File \"/usr/local/lib/python3.10/dist-packages/mmengine/registry/registry.py\", line 570, in build\n",
" return self.build_func(cfg, *args, **kwargs, registry=self)\n",
" File \"/usr/local/lib/python3.10/dist-packages/mmengine/registry/build_functions.py\", line 121, in build_from_cfg\n",
" obj = obj_cls(**args) # type: ignore\n",
" File \"/usr/local/lib/python3.10/dist-packages/opensora/models/stdit/stdit.py\", line 387, in STDiT_XL_2\n",
" load_checkpoint(model, from_pretrained)\n",
" File \"/usr/local/lib/python3.10/dist-packages/opensora/utils/ckpt_utils.py\", line 206, in load_checkpoint\n",
" missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)\n",
" File \"/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\", line 2153, in load_state_dict\n",
" raise RuntimeError('Error(s) in loading state_dict for {}:\\n\\t{}'.format(\n",
"RuntimeError: Error(s) in loading state_dict for STDiT:\n",
"\tsize mismatch for y_embedder.y_embedding: copying a param with shape torch.Size([120, 4096]) from checkpoint, the shape in current model is torch.Size([120, 1024]).\n",
"\tsize mismatch for y_embedder.y_proj.fc1.weight: copying a param with shape torch.Size([1152, 4096]) from checkpoint, the shape in current model is torch.Size([1152, 1024]).\n",
"[2024-03-18 03:18:50,045] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 27322) of binary: /usr/bin/python3\n",
"Traceback (most recent call last):\n",
" File \"/usr/local/bin/torchrun\", line 8, in <module>\n",
" sys.exit(main())\n",
" File \"/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py\", line 347, in wrapper\n",
" return f(*args, **kwargs)\n",
" File \"/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py\", line 812, in main\n",
" run(args)\n",
" File \"/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py\", line 803, in run\n",
" elastic_launch(\n",
" File \"/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py\", line 135, in __call__\n",
" return launch_agent(self._config, self._entrypoint, list(args))\n",
" File \"/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py\", line 268, in launch_agent\n",
" raise ChildFailedError(\n",
"torch.distributed.elastic.multiprocessing.errors.ChildFailedError: \n",
"============================================================\n",
"scripts/inference.py FAILED\n",
"------------------------------------------------------------\n",
"Failures:\n",
" <NO_OTHER_FAILURES>\n",
"------------------------------------------------------------\n",
"Root Cause (first observed failure):\n",
"[0]:\n",
" time : 2024-03-18_03:18:50\n",
" host : 6314233c8d07\n",
" rank : 0 (local_rank: 0)\n",
" exitcode : 1 (pid: 27322)\n",
" error_file: <N/A>\n",
" traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html\n",
"============================================================\n"
]
}
]
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "WVN-1YsgauZX"
},
"execution_count": null,
"outputs": []
}
]
}