torch memprof example using a hacked-up line_profiler

@tmm1 · Last active December 7, 2023 00:49
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "41ddb4ee-66ea-4876-8bb4-38f91f6900c3",
"metadata": {
"collapsed": true,
"jupyter": {
"outputs_hidden": true
},
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu118\n",
"Requirement already satisfied: torch in /home/tmm1/.local/lib/python3.10/site-packages (2.1.1+cu118)\n",
"Requirement already satisfied: filelock in /home/tmm1/.local/lib/python3.10/site-packages (from torch) (3.13.1)\n",
"Requirement already satisfied: typing-extensions in /home/tmm1/.local/lib/python3.10/site-packages (from torch) (4.8.0)\n",
"Requirement already satisfied: sympy in /home/tmm1/.local/lib/python3.10/site-packages (from torch) (1.12)\n",
"Requirement already satisfied: networkx in /home/tmm1/.local/lib/python3.10/site-packages (from torch) (3.2.1)\n",
"Requirement already satisfied: jinja2 in /home/tmm1/micromamba/envs/jupyter/lib/python3.10/site-packages (from torch) (3.1.2)\n",
"Requirement already satisfied: fsspec in /home/tmm1/.local/lib/python3.10/site-packages (from torch) (2023.12.1)\n",
"Requirement already satisfied: triton==2.1.0 in /home/tmm1/.local/lib/python3.10/site-packages (from torch) (2.1.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /home/tmm1/micromamba/envs/jupyter/lib/python3.10/site-packages (from jinja2->torch) (2.1.3)\n",
"Requirement already satisfied: mpmath>=0.19 in /home/tmm1/.local/lib/python3.10/site-packages (from sympy->torch) (1.3.0)\n"
]
}
],
"source": [
"!pip install torch --extra-index-url https://download.pytorch.org/whl/cu118"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "0a098262-4e94-478c-a1c3-8a4199b76db4",
"metadata": {
"collapsed": true,
"jupyter": {
"outputs_hidden": true
},
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Found existing installation: line-profiler 4.1.2\n",
"Uninstalling line-profiler-4.1.2:\n",
" Successfully uninstalled line-profiler-4.1.2\n",
"Using pip 23.3.1 from /home/tmm1/micromamba/envs/jupyter/lib/python3.10/site-packages/pip (python 3.10)\n",
"Collecting git+https://github.com/tmm1/line_profiler@torch-memprof\n",
" Cloning https://github.com/tmm1/line_profiler (to revision torch-memprof) to /tmp/pip-req-build-c8tg5mcf\n",
" Running command git version\n",
" git version 2.34.1\n",
" Running command git clone --filter=blob:none https://github.com/tmm1/line_profiler /tmp/pip-req-build-c8tg5mcf\n",
" Cloning into '/tmp/pip-req-build-c8tg5mcf'...\n",
" Running command git show-ref torch-memprof\n",
" a41c968ad3e446aa8eee9c7ee8a398a235dd9d3a refs/remotes/origin/torch-memprof\n",
" Running command git symbolic-ref -q HEAD\n",
" refs/heads/main\n",
" Running command git checkout -b torch-memprof --track origin/torch-memprof\n",
" Switched to a new branch 'torch-memprof'\n",
" Branch 'torch-memprof' set up to track remote branch 'torch-memprof' from 'origin'.\n",
" Resolved https://github.com/tmm1/line_profiler to commit a41c968ad3e446aa8eee9c7ee8a398a235dd9d3a\n",
" Running command git rev-parse HEAD\n",
" a41c968ad3e446aa8eee9c7ee8a398a235dd9d3a\n",
" Running command Preparing metadata (pyproject.toml)\n",
" warning: line_profiler/_line_profiler.pyx:112:70: Implicit noexcept declaration is deprecated. Function declaration should contain 'noexcept' keyword.\n",
" warning: line_profiler/_line_profiler.pyx:135:34: Implicit noexcept declaration is deprecated. Function declaration should contain 'noexcept' keyword.\n",
" warning: line_profiler/_line_profiler.pyx:222:34: Implicit noexcept declaration is deprecated. Function declaration should contain 'noexcept' keyword.\n",
" warning: line_profiler/_line_profiler.pyx:365:23: Implicit noexcept declaration is deprecated. Function declaration should contain 'noexcept' keyword.\n",
" Compiling line_profiler/_line_profiler.pyx because it changed.\n",
" [1/1] Cythonizing line_profiler/_line_profiler.pyx\n",
" running dist_info\n",
" creating /tmp/pip-modern-metadata-ufo6orkt/line_profiler.egg-info\n",
" writing /tmp/pip-modern-metadata-ufo6orkt/line_profiler.egg-info/PKG-INFO\n",
" writing dependency_links to /tmp/pip-modern-metadata-ufo6orkt/line_profiler.egg-info/dependency_links.txt\n",
" writing entry points to /tmp/pip-modern-metadata-ufo6orkt/line_profiler.egg-info/entry_points.txt\n",
" writing requirements to /tmp/pip-modern-metadata-ufo6orkt/line_profiler.egg-info/requires.txt\n",
" writing top-level names to /tmp/pip-modern-metadata-ufo6orkt/line_profiler.egg-info/top_level.txt\n",
" writing manifest file '/tmp/pip-modern-metadata-ufo6orkt/line_profiler.egg-info/SOURCES.txt'\n",
" file line_profiler.py (for module line_profiler) not found\n",
" reading manifest file '/tmp/pip-modern-metadata-ufo6orkt/line_profiler.egg-info/SOURCES.txt'\n",
" reading manifest template 'MANIFEST.in'\n",
" warning: no files found matching '*.md'\n",
" warning: no files found matching 'run_tests.sh'\n",
" warning: no files found matching '*.pyd' under directory 'line_profiler'\n",
" adding license file 'LICENSE.txt'\n",
" adding license file 'LICENSE_Python.txt'\n",
" writing manifest file '/tmp/pip-modern-metadata-ufo6orkt/line_profiler.egg-info/SOURCES.txt'\n",
" creating '/tmp/pip-modern-metadata-ufo6orkt/line_profiler-4.1.2.dist-info'\n",
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
"Building wheels for collected packages: line-profiler\n",
" Running command git rev-parse HEAD\n",
" a41c968ad3e446aa8eee9c7ee8a398a235dd9d3a\n",
" Running command Building wheel for line-profiler (pyproject.toml)\n",
" running bdist_wheel\n",
" running build\n",
" running build_py\n",
" file line_profiler.py (for module line_profiler) not found\n",
" creating build\n",
" creating build/lib.linux-x86_64-cpython-310\n",
" copying kernprof.py -> build/lib.linux-x86_64-cpython-310\n",
" creating build/lib.linux-x86_64-cpython-310/line_profiler\n",
" copying line_profiler/__main__.py -> build/lib.linux-x86_64-cpython-310/line_profiler\n",
" copying line_profiler/explicit_profiler.py -> build/lib.linux-x86_64-cpython-310/line_profiler\n",
" copying line_profiler/line_profiler.py -> build/lib.linux-x86_64-cpython-310/line_profiler\n",
" copying line_profiler/ipython_extension.py -> build/lib.linux-x86_64-cpython-310/line_profiler\n",
" copying line_profiler/__init__.py -> build/lib.linux-x86_64-cpython-310/line_profiler\n",
" creating build/lib.linux-x86_64-cpython-310/line_profiler/autoprofile\n",
" copying line_profiler/autoprofile/util_static.py -> build/lib.linux-x86_64-cpython-310/line_profiler/autoprofile\n",
" copying line_profiler/autoprofile/autoprofile.py -> build/lib.linux-x86_64-cpython-310/line_profiler/autoprofile\n",
" copying line_profiler/autoprofile/ast_tree_profiler.py -> build/lib.linux-x86_64-cpython-310/line_profiler/autoprofile\n",
" copying line_profiler/autoprofile/profmod_extractor.py -> build/lib.linux-x86_64-cpython-310/line_profiler/autoprofile\n",
" copying line_profiler/autoprofile/ast_profle_transformer.py -> build/lib.linux-x86_64-cpython-310/line_profiler/autoprofile\n",
" copying line_profiler/autoprofile/line_profiler_utils.py -> build/lib.linux-x86_64-cpython-310/line_profiler/autoprofile\n",
" copying line_profiler/autoprofile/__init__.py -> build/lib.linux-x86_64-cpython-310/line_profiler/autoprofile\n",
" copying line_profiler/py.typed -> build/lib.linux-x86_64-cpython-310/line_profiler\n",
" copying line_profiler/ipython_extension.pyi -> build/lib.linux-x86_64-cpython-310/line_profiler\n",
" copying line_profiler/explicit_profiler.pyi -> build/lib.linux-x86_64-cpython-310/line_profiler\n",
" copying line_profiler/line_profiler.pyi -> build/lib.linux-x86_64-cpython-310/line_profiler\n",
" copying line_profiler/__main__.pyi -> build/lib.linux-x86_64-cpython-310/line_profiler\n",
" file line_profiler.py (for module line_profiler) not found\n",
" running build_ext\n",
" building 'line_profiler._line_profiler' extension\n",
" creating build/temp.linux-x86_64-cpython-310\n",
" creating build/temp.linux-x86_64-cpython-310/line_profiler\n",
" gcc -pthread -B /home/tmm1/micromamba/envs/jupyter/compiler_compat -Wno-unused-result -Wsign-compare -DNDEBUG -fwrapv -O2 -Wall -fPIC -O2 -isystem /home/tmm1/micromamba/envs/jupyter/include -fPIC -O2 -isystem /home/tmm1/micromamba/envs/jupyter/include -fPIC -DCYTHON_TRACE=0 -Ipython25.pxd -Iline_profiler -I/home/tmm1/.local/lib/python3.10/site-packages/torch/include -I/home/tmm1/.local/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/home/tmm1/.local/lib/python3.10/site-packages/torch/include/TH -I/home/tmm1/.local/lib/python3.10/site-packages/torch/include/THC -I/home/tmm1/micromamba/envs/jupyter/include/python3.10 -c line_profiler/_line_profiler.cpp -o build/temp.linux-x86_64-cpython-310/line_profiler/_line_profiler.o\n",
" gcc -pthread -B /home/tmm1/micromamba/envs/jupyter/compiler_compat -Wno-unused-result -Wsign-compare -DNDEBUG -fwrapv -O2 -Wall -fPIC -O2 -isystem /home/tmm1/micromamba/envs/jupyter/include -fPIC -O2 -isystem /home/tmm1/micromamba/envs/jupyter/include -fPIC -DCYTHON_TRACE=0 -Ipython25.pxd -Iline_profiler -I/home/tmm1/.local/lib/python3.10/site-packages/torch/include -I/home/tmm1/.local/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/home/tmm1/.local/lib/python3.10/site-packages/torch/include/TH -I/home/tmm1/.local/lib/python3.10/site-packages/torch/include/THC -I/home/tmm1/micromamba/envs/jupyter/include/python3.10 -c line_profiler/timers.c -o build/temp.linux-x86_64-cpython-310/line_profiler/timers.o\n",
" gcc -pthread -B /home/tmm1/micromamba/envs/jupyter/compiler_compat -Wno-unused-result -Wsign-compare -DNDEBUG -fwrapv -O2 -Wall -fPIC -O2 -isystem /home/tmm1/micromamba/envs/jupyter/include -fPIC -O2 -isystem /home/tmm1/micromamba/envs/jupyter/include -fPIC -DCYTHON_TRACE=0 -Ipython25.pxd -Iline_profiler -I/home/tmm1/.local/lib/python3.10/site-packages/torch/include -I/home/tmm1/.local/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/home/tmm1/.local/lib/python3.10/site-packages/torch/include/TH -I/home/tmm1/.local/lib/python3.10/site-packages/torch/include/THC -I/home/tmm1/micromamba/envs/jupyter/include/python3.10 -c line_profiler/unset_trace.c -o build/temp.linux-x86_64-cpython-310/line_profiler/unset_trace.o\n",
" g++ -pthread -B /home/tmm1/micromamba/envs/jupyter/compiler_compat -shared -Wl,--allow-shlib-undefined -Wl,-rpath,/home/tmm1/micromamba/envs/jupyter/lib -Wl,-rpath-link,/home/tmm1/micromamba/envs/jupyter/lib -L/home/tmm1/micromamba/envs/jupyter/lib -Wl,--allow-shlib-undefined -Wl,-rpath,/home/tmm1/micromamba/envs/jupyter/lib -Wl,-rpath-link,/home/tmm1/micromamba/envs/jupyter/lib -L/home/tmm1/micromamba/envs/jupyter/lib build/temp.linux-x86_64-cpython-310/line_profiler/_line_profiler.o build/temp.linux-x86_64-cpython-310/line_profiler/timers.o build/temp.linux-x86_64-cpython-310/line_profiler/unset_trace.o -L/home/tmm1/.local/lib/python3.10/site-packages/torch/lib -Wl,-R/home/tmm1/.local/lib/python3.10/site-packages/torch/lib -lc10_cuda -o build/lib.linux-x86_64-cpython-310/line_profiler/_line_profiler.cpython-310-x86_64-linux-gnu.so\n",
" installing to build/bdist.linux-x86_64/wheel\n",
" running install\n",
" running install_lib\n",
" creating build/bdist.linux-x86_64\n",
" creating build/bdist.linux-x86_64/wheel\n",
" copying build/lib.linux-x86_64-cpython-310/kernprof.py -> build/bdist.linux-x86_64/wheel\n",
" creating build/bdist.linux-x86_64/wheel/line_profiler\n",
" copying build/lib.linux-x86_64-cpython-310/line_profiler/__main__.py -> build/bdist.linux-x86_64/wheel/line_profiler\n",
" copying build/lib.linux-x86_64-cpython-310/line_profiler/ipython_extension.pyi -> build/bdist.linux-x86_64/wheel/line_profiler\n",
" creating build/bdist.linux-x86_64/wheel/line_profiler/autoprofile\n",
" copying build/lib.linux-x86_64-cpython-310/line_profiler/autoprofile/util_static.py -> build/bdist.linux-x86_64/wheel/line_profiler/autoprofile\n",
" copying build/lib.linux-x86_64-cpython-310/line_profiler/autoprofile/autoprofile.py -> build/bdist.linux-x86_64/wheel/line_profiler/autoprofile\n",
" copying build/lib.linux-x86_64-cpython-310/line_profiler/autoprofile/ast_tree_profiler.py -> build/bdist.linux-x86_64/wheel/line_profiler/autoprofile\n",
" copying build/lib.linux-x86_64-cpython-310/line_profiler/autoprofile/profmod_extractor.py -> build/bdist.linux-x86_64/wheel/line_profiler/autoprofile\n",
" copying build/lib.linux-x86_64-cpython-310/line_profiler/autoprofile/ast_profle_transformer.py -> build/bdist.linux-x86_64/wheel/line_profiler/autoprofile\n",
" copying build/lib.linux-x86_64-cpython-310/line_profiler/autoprofile/line_profiler_utils.py -> build/bdist.linux-x86_64/wheel/line_profiler/autoprofile\n",
" copying build/lib.linux-x86_64-cpython-310/line_profiler/autoprofile/__init__.py -> build/bdist.linux-x86_64/wheel/line_profiler/autoprofile\n",
" copying build/lib.linux-x86_64-cpython-310/line_profiler/explicit_profiler.py -> build/bdist.linux-x86_64/wheel/line_profiler\n",
" copying build/lib.linux-x86_64-cpython-310/line_profiler/explicit_profiler.pyi -> build/bdist.linux-x86_64/wheel/line_profiler\n",
" copying build/lib.linux-x86_64-cpython-310/line_profiler/line_profiler.pyi -> build/bdist.linux-x86_64/wheel/line_profiler\n",
" copying build/lib.linux-x86_64-cpython-310/line_profiler/py.typed -> build/bdist.linux-x86_64/wheel/line_profiler\n",
" copying build/lib.linux-x86_64-cpython-310/line_profiler/line_profiler.py -> build/bdist.linux-x86_64/wheel/line_profiler\n",
" copying build/lib.linux-x86_64-cpython-310/line_profiler/__main__.pyi -> build/bdist.linux-x86_64/wheel/line_profiler\n",
" copying build/lib.linux-x86_64-cpython-310/line_profiler/ipython_extension.py -> build/bdist.linux-x86_64/wheel/line_profiler\n",
" copying build/lib.linux-x86_64-cpython-310/line_profiler/_line_profiler.cpython-310-x86_64-linux-gnu.so -> build/bdist.linux-x86_64/wheel/line_profiler\n",
" copying build/lib.linux-x86_64-cpython-310/line_profiler/__init__.py -> build/bdist.linux-x86_64/wheel/line_profiler\n",
" running install_egg_info\n",
" running egg_info\n",
" creating line_profiler.egg-info\n",
" writing line_profiler.egg-info/PKG-INFO\n",
" writing dependency_links to line_profiler.egg-info/dependency_links.txt\n",
" writing entry points to line_profiler.egg-info/entry_points.txt\n",
" writing requirements to line_profiler.egg-info/requires.txt\n",
" writing top-level names to line_profiler.egg-info/top_level.txt\n",
" writing manifest file 'line_profiler.egg-info/SOURCES.txt'\n",
" file line_profiler.py (for module line_profiler) not found\n",
" reading manifest file 'line_profiler.egg-info/SOURCES.txt'\n",
" reading manifest template 'MANIFEST.in'\n",
" warning: no files found matching '*.md'\n",
" warning: no files found matching 'run_tests.sh'\n",
" warning: no files found matching '*.pyd' under directory 'line_profiler'\n",
" adding license file 'LICENSE.txt'\n",
" adding license file 'LICENSE_Python.txt'\n",
" writing manifest file 'line_profiler.egg-info/SOURCES.txt'\n",
" Copying line_profiler.egg-info to build/bdist.linux-x86_64/wheel/line_profiler-4.1.2-py3.10.egg-info\n",
" running install_scripts\n",
" creating build/bdist.linux-x86_64/wheel/line_profiler-4.1.2.dist-info/WHEEL\n",
" creating '/tmp/pip-wheel-hyu1tqo_/.tmp-i5cwcjy5/line_profiler-4.1.2-cp310-cp310-linux_x86_64.whl' and adding 'build/bdist.linux-x86_64/wheel' to it\n",
" adding 'kernprof.py'\n",
" adding 'line_profiler/__init__.py'\n",
" adding 'line_profiler/__main__.py'\n",
" adding 'line_profiler/__main__.pyi'\n",
" adding 'line_profiler/_line_profiler.cpython-310-x86_64-linux-gnu.so'\n",
" adding 'line_profiler/explicit_profiler.py'\n",
" adding 'line_profiler/explicit_profiler.pyi'\n",
" adding 'line_profiler/ipython_extension.py'\n",
" adding 'line_profiler/ipython_extension.pyi'\n",
" adding 'line_profiler/line_profiler.py'\n",
" adding 'line_profiler/line_profiler.pyi'\n",
" adding 'line_profiler/py.typed'\n",
" adding 'line_profiler/autoprofile/__init__.py'\n",
" adding 'line_profiler/autoprofile/ast_profle_transformer.py'\n",
" adding 'line_profiler/autoprofile/ast_tree_profiler.py'\n",
" adding 'line_profiler/autoprofile/autoprofile.py'\n",
" adding 'line_profiler/autoprofile/line_profiler_utils.py'\n",
" adding 'line_profiler/autoprofile/profmod_extractor.py'\n",
" adding 'line_profiler/autoprofile/util_static.py'\n",
" adding 'line_profiler-4.1.2.dist-info/LICENSE.txt'\n",
" adding 'line_profiler-4.1.2.dist-info/LICENSE_Python.txt'\n",
" adding 'line_profiler-4.1.2.dist-info/METADATA'\n",
" adding 'line_profiler-4.1.2.dist-info/WHEEL'\n",
" adding 'line_profiler-4.1.2.dist-info/entry_points.txt'\n",
" adding 'line_profiler-4.1.2.dist-info/top_level.txt'\n",
" adding 'line_profiler-4.1.2.dist-info/RECORD'\n",
" removing build/bdist.linux-x86_64/wheel\n",
" Building wheel for line-profiler (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for line-profiler: filename=line_profiler-4.1.2-cp310-cp310-linux_x86_64.whl size=181416 sha256=bec7cdcaeef783ca0b0c9d8107039b378ffec744b6be457d3577110be1a40bed\n",
" Stored in directory: /tmp/pip-ephem-wheel-cache-bu_ugx9y/wheels/8b/53/99/b13c068bc9911e602585da8b2fbc756968a6f2fca8745a125a\n",
"Successfully built line-profiler\n",
"Installing collected packages: line-profiler\n",
" changing mode of /home/tmm1/micromamba/envs/jupyter/bin/kernprof to 755\n",
"Successfully installed line-profiler-4.1.2\n"
]
}
],
"source": [
"!pip uninstall --no-input -y line-profiler\n",
"!pip install -v --no-build-isolation git+https://github.com/tmm1/line_profiler@torch-memprof"
]
},
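{
"cell_type": "markdown",
"id": "memprof-note-1",
"metadata": {},
"source": [
"Judging by the build log above, the `torch-memprof` branch compiles `line_profiler`'s Cython extension against the torch headers and links it with `-lc10_cuda`, so it can read the CUDA caching allocator's counters on every traced line instead of a timer. A minimal pure-Python sketch of the same idea (the names below are illustrative, not the branch's actual API):\n",
"\n",
"```python\n",
"import sys\n",
"import torch\n",
"\n",
"def make_mem_tracer(stats):\n",
"    # 'line' events fire *before* a line runs, so the allocation delta\n",
"    # observed at each event belongs to the line that just finished.\n",
"    state = {\"prev_key\": None, \"prev_mem\": torch.cuda.memory_allocated()}\n",
"    def tracer(frame, event, arg):\n",
"        if event == \"line\":\n",
"            now = torch.cuda.memory_allocated()\n",
"            if state[\"prev_key\"] is not None:\n",
"                stats[state[\"prev_key\"]] = stats.get(state[\"prev_key\"], 0) + now - state[\"prev_mem\"]\n",
"            state[\"prev_key\"] = (frame.f_code.co_filename, frame.f_lineno)\n",
"            state[\"prev_mem\"] = now\n",
"        return tracer\n",
"    return tracer\n",
"\n",
"stats = {}\n",
"sys.settrace(make_mem_tracer(stats))\n",
"# ... run the code under test, then:\n",
"sys.settrace(None)\n",
"```"
]
},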
{
"cell_type": "code",
"execution_count": 3,
"id": "e31e30e7-802c-45e0-a59c-9ba2f9e910ec",
"metadata": {
"collapsed": true,
"jupyter": {
"outputs_hidden": true
},
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: transformers in /home/tmm1/.local/lib/python3.10/site-packages (4.35.2)\n",
"Requirement already satisfied: flash-attn in /home/tmm1/.local/lib/python3.10/site-packages (2.3.6)\n",
"Requirement already satisfied: filelock in /home/tmm1/.local/lib/python3.10/site-packages (from transformers) (3.13.1)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.16.4 in /home/tmm1/.local/lib/python3.10/site-packages (from transformers) (0.19.4)\n",
"Requirement already satisfied: numpy>=1.17 in /home/tmm1/.local/lib/python3.10/site-packages (from transformers) (1.26.2)\n",
"Requirement already satisfied: packaging>=20.0 in /home/tmm1/.local/lib/python3.10/site-packages (from transformers) (23.2)\n",
"Requirement already satisfied: pyyaml>=5.1 in /home/tmm1/micromamba/envs/jupyter/lib/python3.10/site-packages (from transformers) (6.0.1)\n",
"Requirement already satisfied: regex!=2019.12.17 in /home/tmm1/.local/lib/python3.10/site-packages (from transformers) (2023.10.3)\n",
"Requirement already satisfied: requests in /home/tmm1/micromamba/envs/jupyter/lib/python3.10/site-packages (from transformers) (2.31.0)\n",
"Requirement already satisfied: tokenizers<0.19,>=0.14 in /home/tmm1/.local/lib/python3.10/site-packages (from transformers) (0.15.0)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /home/tmm1/.local/lib/python3.10/site-packages (from transformers) (0.4.1)\n",
"Requirement already satisfied: tqdm>=4.27 in /home/tmm1/.local/lib/python3.10/site-packages (from transformers) (4.66.1)\n",
"Requirement already satisfied: torch in /home/tmm1/.local/lib/python3.10/site-packages (from flash-attn) (2.1.1+cu118)\n",
"Requirement already satisfied: einops in /home/tmm1/.local/lib/python3.10/site-packages (from flash-attn) (0.7.0)\n",
"Requirement already satisfied: ninja in /home/tmm1/.local/lib/python3.10/site-packages (from flash-attn) (1.11.1.1)\n",
"Requirement already satisfied: fsspec>=2023.5.0 in /home/tmm1/.local/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.16.4->transformers) (2023.12.1)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/tmm1/.local/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.16.4->transformers) (4.8.0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /home/tmm1/micromamba/envs/jupyter/lib/python3.10/site-packages (from requests->transformers) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /home/tmm1/micromamba/envs/jupyter/lib/python3.10/site-packages (from requests->transformers) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /home/tmm1/micromamba/envs/jupyter/lib/python3.10/site-packages (from requests->transformers) (2.1.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /home/tmm1/micromamba/envs/jupyter/lib/python3.10/site-packages (from requests->transformers) (2023.11.17)\n",
"Requirement already satisfied: sympy in /home/tmm1/.local/lib/python3.10/site-packages (from torch->flash-attn) (1.12)\n",
"Requirement already satisfied: networkx in /home/tmm1/.local/lib/python3.10/site-packages (from torch->flash-attn) (3.2.1)\n",
"Requirement already satisfied: jinja2 in /home/tmm1/micromamba/envs/jupyter/lib/python3.10/site-packages (from torch->flash-attn) (3.1.2)\n",
"Requirement already satisfied: triton==2.1.0 in /home/tmm1/.local/lib/python3.10/site-packages (from torch->flash-attn) (2.1.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /home/tmm1/micromamba/envs/jupyter/lib/python3.10/site-packages (from jinja2->torch->flash-attn) (2.1.3)\n",
"Requirement already satisfied: mpmath>=0.19 in /home/tmm1/.local/lib/python3.10/site-packages (from sympy->torch->flash-attn) (1.3.0)\n"
]
}
],
"source": [
"!pip install -U transformers flash-attn"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "a4a6548e-2937-4abc-a58f-5c69d98ad52d",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ['CUDA_LAUNCH_BLOCKING'] = \"1\""
]
},
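{
"cell_type": "markdown",
"id": "memprof-note-2",
"metadata": {},
"source": [
"`CUDA_LAUNCH_BLOCKING=1` forces every kernel launch to be synchronous. Launches are normally asynchronous, so a per-line profiler would blame the cost (and any allocator activity) on whichever later line happens to synchronize, rather than on the line that launched the work:\n",
"\n",
"```python\n",
"import torch\n",
"x = torch.randn(4096, 4096, device=\"cuda\")\n",
"# with asynchronous launches (the default):\n",
"y = x @ x           # kernel queued, the line \"finishes\" instantly\n",
"z = y.sum().item()  # .item() synchronizes, so time/blame lands here\n",
"```\n",
"\n",
"The variable has to be set before the CUDA context is created, hence setting it at the top of the notebook."
]
},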
{
"cell_type": "code",
"execution_count": 2,
"id": "d2f5c1ab-cb25-4060-b139-9e212a47e853",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9dceecb8c52c4ba0bdecd86e897c36fe",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import torch\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"device = \"cuda\" # the device to load the model onto\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" \"mistralai/Mistral-7B-v0.1\",\n",
" use_flash_attention_2=True,\n",
" torch_dtype=torch.bfloat16\n",
").to(device)\n",
"tokenizer = AutoTokenizer.from_pretrained(\"mistralai/Mistral-7B-v0.1\")"
]
},
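{
"cell_type": "markdown",
"id": "memprof-note-3",
"metadata": {},
"source": [
"`use_flash_attention_2=True` works on the transformers 4.35.2 installed above; newer releases deprecate it in favor of the `attn_implementation` argument, so the equivalent call there would be roughly:\n",
"\n",
"```python\n",
"import torch\n",
"from transformers import AutoModelForCausalLM\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
"    \"mistralai/Mistral-7B-v0.1\",\n",
"    attn_implementation=\"flash_attention_2\",\n",
"    torch_dtype=torch.bfloat16,\n",
").to(\"cuda\")\n",
"```"
]
},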
{
"cell_type": "code",
"execution_count": 3,
"id": "4ef8437c-1dec-43be-aac9-1aaef6dac75c",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Thu Dec 7 00:45:09 2023 \n",
"+---------------------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\n",
"|-----------------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|=========================================+======================+======================|\n",
"| 0 NVIDIA GeForce RTX 3090 On | 00000000:18:00.0 Off | N/A |\n",
"| 0% 46C P2 140W / 370W | 14849MiB / 24576MiB | 53% Default |\n",
"| | | N/A |\n",
"+-----------------------------------------+----------------------+----------------------+\n",
"| 1 NVIDIA GeForce RTX 3090 On | 00000000:29:00.0 Off | N/A |\n",
"| 0% 28C P8 22W / 420W | 8MiB / 24576MiB | 0% Default |\n",
"| | | N/A |\n",
"+-----------------------------------------+----------------------+----------------------+\n",
" \n",
"+---------------------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=======================================================================================|\n",
"| 0 N/A N/A 112619 C ...romamba/envs/jupyter/bin/python3.10 14836MiB |\n",
"+---------------------------------------------------------------------------------------+\n"
]
}
],
"source": [
"!nvidia-smi"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "7fa87249-b894-4dbc-b941-e241fbd9335c",
"metadata": {},
"outputs": [],
"source": [
"def run_prompt(model, prompt):\n",
" torch.cuda.empty_cache()\n",
" model_inputs = tokenizer([prompt], return_tensors=\"pt\").to(device)\n",
" generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)\n",
" return tokenizer.batch_decode(generated_ids)[0]"
]
},
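{
"cell_type": "markdown",
"id": "memprof-note-4",
"metadata": {},
"source": [
"A note on what gets measured: the per-line numbers below are presumably deltas in live tensor allocations (`torch.cuda.memory_allocated()`), while `torch.cuda.empty_cache()` only returns the caching allocator's unused blocks to the driver, which is what makes the `nvidia-smi` figure drop. The two counters can be inspected directly:\n",
"\n",
"```python\n",
"import torch\n",
"print(torch.cuda.memory_allocated() / 2**20, \"MiB of live tensors\")\n",
"print(torch.cuda.memory_reserved() / 2**20, \"MiB held by the caching allocator\")\n",
"```"
]
},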
{
"cell_type": "code",
"execution_count": 8,
"id": "bba4ebd3-54ee-4eee-b0d8-2382f2dc7604",
"metadata": {},
"outputs": [],
"source": [
"prompt = \"My favourite condiment is\""
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "adae8f4c-a98b-4c58-8510-99445903e0ec",
"metadata": {},
"outputs": [],
"source": [
"prompt = \"\"\"\n",
"Mistral AI, a well-funded artificial intelligence startup that launched five months ago, today released an open-source language model with 7 billion parameters.\n",
"\n",
"The model is called Mistral 7B in a nod to its parameter count. It’s available on GitHub under an Apache 2.0 license. According to the company, the model may be used for both research and commercial purposes.\n",
"\n",
"Paris-based Mistral AI was founded in May by former Meta Platforms Inc. and Google LLC researchers. Chief Executive Officer Arthur Mensch worked at the search giant’s DeepMind machine learning unit before launching the company. Chief Science Officer Guillaume Lample led the development of Meta’s open-source Llama language model.\n",
"\n",
"Four weeks after launching in May, Mistral AI closed a €105 million funding round at a €240 million valuation. The investment included contributions from Lightspeed Venture Partners, Index Ventures, Redpoint Ventures and more than a half dozen other backers. Mistral AI said at the time that it was planning to introduce its first language models in 2024.\n",
"\n",
"The release of the Mistral 7B language model today suggests that the development effort is advancing faster than expected. In a blog post, the company detailed that the model took three months to develop. In that time frame, Mistral AI’s founders assembled an engineering team and built a so-called MLOps stack, a collection of specialized software tools used for neural network development.\n",
"\n",
"The company says Mistral 7B can generate prose, summarize documents and perform other text processing tasks. It’s also capable of autocompleting software code written by developers. The model has a context length of 8k, which means that each prompt entered by users may contain up to 8,000 tokens.\n",
"\n",
"At the architectural level, Mistral AI features 7 billion parameters. Those are the configuration settings that determine how a neural network goes about processing data. The most advanced AI systems on the market today have hundreds of millions of such settings.\n",
"\n",
"The company claims that Mistral 7B “outperforms all currently available open models up to 13B parameters on all standard English and code benchmarks.” That includes the 13 billion parameter version of Llama 2, an advanced language model Meta released earlier this year. Moreover, Mistral 7B achieved performance “on par ” with a 34 billion parameter version of Meta’s Llama model, a predecessor to Llama 2.\n",
"\n",
"Mistral AI says its model can match the performance of larger neural networks while using less hardware. Lowering an AI’s hardware requirements not only decreases the cost of running it but also improves performance. As a result, the company sees Mistral 7B coming particularly for latency-sensitive use cases.\n",
"\n",
"Mistral 7B is the first in a series of large language models that the company plans to release. The upcoming additions to the lineup are expected to be better at reasoning tasks and support more languages. In the long term, Mistral AI also plans to offer hosted neural networks for the ENTERPRISES.\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "0778482b-3f6f-4ddb-ab9b-9bb377e457e0",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
]
},
{
"data": {
"text/plain": [
"'<s> \\nMistral AI, a well-funded artificial intelligence startup that launched five months ago, today released an open-source language model with 7 billion parameters.\\n\\nThe model is called Mistral 7B in a nod to its parameter count. It’s available on GitHub under an Apache 2.0 license. According to the company, the model may be used for both research and commercial purposes.\\n\\nParis-based Mistral AI was founded in May by former Meta Platforms Inc. and Google LLC researchers. Chief Executive Officer Arthur Mensch worked at the search giant’s DeepMind machine learning unit before launching the company. Chief Science Officer Guillaume Lample led the development of Meta’s open-source Llama language model.\\n\\nFour weeks after launching in May, Mistral AI closed a €105 million funding round at a €240 million valuation. The investment included contributions from Lightspeed Venture Partners, Index Ventures, Redpoint Ventures and more than a half dozen other backers. Mistral AI said at the time that it was planning to introduce its first language models in 2024.\\n\\nThe release of the Mistral 7B language model today suggests that the development effort is advancing faster than expected. In a blog post, the company detailed that the model took three months to develop. In that time frame, Mistral AI’s founders assembled an engineering team and built a so-called MLOps stack, a collection of specialized software tools used for neural network development.\\n\\nThe company says Mistral 7B can generate prose, summarize documents and perform other text processing tasks. It’s also capable of autocompleting software code written by developers. The model has a context length of 8k, which means that each prompt entered by users may contain up to 8,000 tokens.\\n\\nAt the architectural level, Mistral AI features 7 billion parameters. Those are the configuration settings that determine how a neural network goes about processing data. The most advanced AI systems on the market today have hundreds of millions of such settings.\\n\\nThe company claims that Mistral 7B “outperforms all currently available open models up to 13B parameters on all standard English and code benchmarks.” That includes the 13 billion parameter version of Llama 2, an advanced language model Meta released earlier this year. Moreover, Mistral 7B achieved performance “on par ” with a 34 billion parameter version of Meta’s Llama model, a predecessor to Llama 2.\\n\\nMistral AI says its model can match the performance of larger neural networks while using less hardware. Lowering an AI’s hardware requirements not only decreases the cost of running it but also improves performance. As a result, the company sees Mistral 7B coming particularly for latency-sensitive use cases.\\n\\nMistral 7B is the first in a series of large language models that the company plans to release. The upcoming additions to the lineup are expected to be better at reasoning tasks and support more languages. In the long term, Mistral AI also plans to offer hosted neural networks for the ENTERPRISES.\\n\\n—\\n\\nSource of this (above) article: © TechXplore News</s>'"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"run_prompt(model, prompt)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "84ca755b-e287-46da-8ca2-09ff30929d50",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mem unit: 1e-06 bytes\n",
"\n",
"Total memory: 5441.72 MB\n",
"File: /home/tmm1/.local/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py\n",
"Function: forward at line 313\n",
"\n",
"Line # Hits Mem Per Hit % Mem Line Contents\n",
"==============================================================\n",
" 313 def forward(\n",
" 314 self,\n",
" 315 hidden_states: torch.Tensor,\n",
" 316 attention_mask: Optional[torch.Tensor] = None,\n",
" 317 position_ids: Optional[torch.LongTensor] = None,\n",
" 318 past_key_value: Optional[Tuple[torch.Tensor]] = None,\n",
" 319 output_attentions: bool = False,\n",
" 320 use_cache: bool = False,\n",
" 321 **kwargs,\n",
" 322 ):\n",
" 323 352 if \"padding_mask\" in kwargs:\n",
" 324 warnings.warn(\n",
" 325 \"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`\"\n",
" 326 )\n",
" 327 \n",
" 328 # overwrite attention_mask with padding_mask\n",
" 329 attention_mask = kwargs.pop(\"padding_mask\")\n",
" 330 352 bsz, q_len, _ = hidden_states.size()\n",
" 331 \n",
" 332 352 184.5 0.5 3.4 query_states = self.q_proj(hidden_states)\n",
" 333 352 46.1 0.1 0.8 key_states = self.k_proj(hidden_states)\n",
" 334 352 46.1 0.1 0.8 value_states = self.v_proj(hidden_states)\n",
" 335 \n",
" 336 352 query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n",
" 337 352 key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n",
" 338 352 value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n",
" 339 \n",
" 340 352 kv_seq_len = key_states.shape[-2]\n",
" 341 352 if past_key_value is not None:\n",
" 342 320 kv_seq_len += past_key_value[0].shape[-2]\n",
" 343 \n",
" 344 # Because the input can be padded, the absolute sequence length depends on the max position id.\n",
" 345 352 rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1\n",
" 346 352 cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)\n",
" 347 \n",
" 348 352 query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)\n",
" 349 \n",
" 350 352 use_sliding_windows = (\n",
" 351 704 _flash_supports_window_size\n",
" 352 352 and hasattr(self.config, \"sliding_window\") is not None\n",
" 353 352 and kv_seq_len > self.config.sliding_window\n",
" 354 )\n",
" 355 \n",
" 356 352 if not _flash_supports_window_size:\n",
" 357 logger.warning_once(\n",
" 358 \"The current flash attention version does not support sliding window attention, for a more memory efficient implementation\"\n",
" 359 \" make sure to upgrade flash-attn library.\"\n",
" 360 )\n",
" 361 \n",
" 362 352 if past_key_value is not None:\n",
" 363 # Activate slicing cache only if the config has a value `sliding_windows` attribute\n",
" 364 320 if hasattr(self.config, \"sliding_window\") and kv_seq_len > self.config.sliding_window:\n",
" 365 slicing_tokens = kv_seq_len - self.config.sliding_window\n",
" 366 \n",
" 367 past_key = past_key_value[0]\n",
" 368 past_value = past_key_value[1]\n",
" 369 \n",
" 370 past_key = past_key[:, :, slicing_tokens:, :].contiguous()\n",
" 371 past_value = past_value[:, :, slicing_tokens:, :].contiguous()\n",
" 372 \n",
" 373 if past_key.shape[-2] != self.config.sliding_window - 1:\n",
" 374 raise ValueError(\n",
" 375 f\"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got\"\n",
" 376 f\" {past_key.shape}\"\n",
" 377 )\n",
" 378 \n",
" 379 past_key_value = (past_key, past_value)\n",
" 380 \n",
" 381 if attention_mask is not None:\n",
" 382 attention_mask = attention_mask[:, slicing_tokens:]\n",
" 383 attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)\n",
" 384 \n",
" 385 320 457.8 1.4 8.4 key_states = torch.cat([past_key_value[0], key_states], dim=2)\n",
" 386 320 457.8 1.4 8.4 value_states = torch.cat([past_key_value[1], value_states], dim=2)\n",
" 387 \n",
" 388 352 past_key_value = (key_states, value_states) if use_cache else None\n",
" 389 \n",
" 390 # repeat k/v heads if n_kv_heads < n_heads\n",
" 391 352 2015.6 5.7 37.0 key_states = repeat_kv(key_states, self.num_key_value_groups)\n",
" 392 352 2049.2 5.8 37.7 value_states = repeat_kv(value_states, self.num_key_value_groups)\n",
" 393 \n",
" 394 # TODO: Mistral does not have dropout in the config??\n",
" 395 # It is recommended to use dropout with FA according to the docs\n",
" 396 # when training.\n",
" 397 352 dropout_rate = 0.0 # if not self.training else self.attn_dropout\n",
" 398 \n",
" 399 # In PEFT, usually we cast the layer norms in float32 for training stability reasons\n",
" 400 # therefore the input hidden states gets silently casted in float32. Hence, we need\n",
" 401 # cast them back in float16 just to be sure everything works as expected.\n",
" 402 352 input_dtype = query_states.dtype\n",
" 403 352 if input_dtype == torch.float32:\n",
" 404 # Handle the case where the model is quantized\n",
" 405 if hasattr(self.config, \"_pre_quantization_dtype\"):\n",
" 406 target_dtype = self.config._pre_quantization_dtype\n",
" 407 else:\n",
" 408 target_dtype = self.q_proj.weight.dtype\n",
" 409 \n",
" 410 logger.warning_once(\n",
" 411 f\"The input hidden states seems to be silently casted in float32, this might be related to\"\n",
" 412 f\" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in\"\n",
" 413 f\" {target_dtype}.\"\n",
" 414 )\n",
" 415 \n",
" 416 query_states = query_states.to(target_dtype)\n",
" 417 key_states = key_states.to(target_dtype)\n",
" 418 value_states = value_states.to(target_dtype)\n",
" 419 \n",
" 420 # Reashape to the expected shape for Flash Attention\n",
" 421 352 query_states = query_states.transpose(1, 2)\n",
" 422 352 key_states = key_states.transpose(1, 2)\n",
" 423 352 value_states = value_states.transpose(1, 2)\n",
" 424 \n",
" 425 704 184.5 0.3 3.4 attn_output = self._flash_attention_forward(\n",
" 426 352 query_states,\n",
" 427 352 key_states,\n",
" 428 352 value_states,\n",
" 429 352 attention_mask,\n",
" 430 352 q_len,\n",
" 431 352 dropout=dropout_rate,\n",
" 432 352 use_sliding_windows=use_sliding_windows,\n",
" 433 )\n",
" 434 \n",
" 435 352 attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()\n",
" 436 352 attn_output = self.o_proj(attn_output)\n",
" 437 \n",
" 438 352 if not output_attentions:\n",
" 439 352 attn_weights = None\n",
" 440 \n",
" 441 352 return attn_output, attn_weights, past_key_value\n",
"\n",
"Total memory: 0.017408 MB\n",
"File: /tmp/ipykernel_112619/3258697482.py\n",
"Function: run_prompt at line 1\n",
"\n",
"Line # Hits Mem Per Hit % Mem Line Contents\n",
"==============================================================\n",
"\n"
]
}
],
"source": [
"import transformers\n",
"from line_profiler import LineProfiler\n",
"instrumented = [\n",
" transformers.models.mistral.modeling_mistral.MistralAttention.forward,\n",
" transformers.models.mistral.modeling_mistral.MistralFlashAttention2.forward,\n",
"]\n",
"lp = LineProfiler(*instrumented)\n",
"lp(run_prompt)(model, prompt)\n",
"lp.print_stats()"
]
},
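{
"cell_type": "markdown",
"id": "memprof-note-5",
"metadata": {},
"source": [
"In the flash-attention profile above, `repeat_kv` accounts for roughly 75% of the ~5.4 GB allocated in `forward` (37.0% + 37.7%). Mistral-7B uses grouped-query attention, with 8 KV heads serving 32 query heads, and `repeat_kv` materializes that 4x expansion as a real copy. A sketch of why, using Mistral-7B's shapes for illustration:\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"bsz, n_kv, n_rep, seqlen, head_dim = 1, 8, 4, 4096, 128\n",
"kv = torch.empty(bsz, n_kv, seqlen, head_dim, dtype=torch.bfloat16, device=\"cuda\")\n",
"# expand() is a free view, but the reshape() that follows forces a\n",
"# contiguous copy, allocating n_rep times the memory of the KV tensor\n",
"expanded = kv[:, :, None].expand(bsz, n_kv, n_rep, seqlen, head_dim)\n",
"full = expanded.reshape(bsz, n_kv * n_rep, seqlen, head_dim)\n",
"```\n",
"\n",
"Since flash-attn 2 handles grouped KV heads natively, this expansion is arguably avoidable on this code path."
]
},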
{
"cell_type": "code",
"execution_count": 8,
"id": "c55a26e0-7795-429a-98f7-a8d5e4e3f310",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2ec02baaa31840eb9e4fdd805739b188",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mem unit: 1e-06 bytes\n",
"\n",
"Total memory: 24370.2 MB\n",
"File: /home/tmm1/.local/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py\n",
"Function: forward at line 228\n",
"\n",
"Line # Hits Mem Per Hit % Mem Line Contents\n",
"==============================================================\n",
" 228 def forward(\n",
" 229 self,\n",
" 230 hidden_states: torch.Tensor,\n",
" 231 attention_mask: Optional[torch.Tensor] = None,\n",
" 232 position_ids: Optional[torch.LongTensor] = None,\n",
" 233 past_key_value: Optional[Tuple[torch.Tensor]] = None,\n",
" 234 output_attentions: bool = False,\n",
" 235 use_cache: bool = False,\n",
" 236 **kwargs,\n",
" 237 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n",
" 238 1568 if \"padding_mask\" in kwargs:\n",
" 239 warnings.warn(\n",
" 240 \"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`\"\n",
" 241 )\n",
" 242 1568 bsz, q_len, _ = hidden_states.size()\n",
" 243 \n",
" 244 1568 194.5 0.1 0.8 query_states = self.q_proj(hidden_states)\n",
" 245 1568 48.6 0.0 0.2 key_states = self.k_proj(hidden_states)\n",
" 246 1568 48.6 0.0 0.2 value_states = self.v_proj(hidden_states)\n",
" 247 \n",
" 248 1568 query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n",
" 249 1568 key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n",
" 250 1568 value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n",
" 251 \n",
" 252 1568 kv_seq_len = key_states.shape[-2]\n",
" 253 1568 if past_key_value is not None:\n",
" 254 1536 kv_seq_len += past_key_value[0].shape[-2]\n",
" 255 1568 cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)\n",
" 256 1568 query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)\n",
" 257 \n",
" 258 1568 if past_key_value is not None:\n",
" 259 # reuse k, v, self_attention\n",
" 260 1536 2257.1 1.5 9.3 key_states = torch.cat([past_key_value[0], key_states], dim=2)\n",
" 261 1536 2261.2 1.5 9.3 value_states = torch.cat([past_key_value[1], value_states], dim=2)\n",
" 262 \n",
" 263 1568 past_key_value = (key_states, value_states) if use_cache else None\n",
" 264 \n",
" 265 # repeat k/v heads if n_kv_heads < n_heads\n",
" 266 1568 9222.8 5.9 37.8 key_states = repeat_kv(key_states, self.num_key_value_groups)\n",
" 267 1568 10142.9 6.5 41.6 value_states = repeat_kv(value_states, self.num_key_value_groups)\n",
" 268 \n",
" 269 1568 1077.6 0.7 4.4 attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)\n",
" 270 \n",
" 271 1568 if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):\n",
" 272 raise ValueError(\n",
" 273 f\"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is\"\n",
" 274 f\" {attn_weights.size()}\"\n",
" 275 )\n",
" 276 \n",
" 277 1568 if attention_mask is not None:\n",
" 278 1568 if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):\n",
" 279 raise ValueError(\n",
" 280 f\"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}\"\n",
" 281 )\n",
" 282 \n",
" 283 1568 attn_weights = attn_weights + attention_mask\n",
" 284 \n",
" 285 # upcast attention to fp32\n",
" 286 1568 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)\n",
" 287 1568 194.5 0.1 0.8 attn_output = torch.matmul(attn_weights, value_states)\n",
" 288 \n",
" 289 1568 if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):\n",
" 290 raise ValueError(\n",
" 291 f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is\"\n",
" 292 f\" {attn_output.size()}\"\n",
" 293 )\n",
" 294 \n",
" 295 1568 attn_output = attn_output.transpose(1, 2).contiguous()\n",
" 296 1568 attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)\n",
" 297 \n",
" 298 1568 attn_output = self.o_proj(attn_output)\n",
" 299 \n",
" 300 1568 if not output_attentions:\n",
" 301 1568 -1077.6 -0.7 -4.4 attn_weights = None\n",
" 302 \n",
" 303 1568 return attn_output, attn_weights, past_key_value\n",
"\n",
"Total memory: 0.017408 MB\n",
"File: /tmp/ipykernel_112619/3258697482.py\n",
"Function: run_prompt at line 1\n",
"\n",
"Line # Hits Mem Per Hit % Mem Line Contents\n",
"==============================================================\n",
"\n"
]
}
],
"source": [
"del model\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" \"mistralai/Mistral-7B-v0.1\",\n",
" torch_dtype=torch.bfloat16\n",
").to(device)\n",
"\n",
"instrumented = [\n",
" transformers.models.mistral.modeling_mistral.MistralAttention.forward,\n",
" transformers.models.mistral.modeling_mistral.MistralFlashAttention2.forward,\n",
"]\n",
"lp = LineProfiler(*instrumented)\n",
"lp(run_prompt)(model, prompt)\n",
"lp.print_stats()"
]
}
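,
{
"cell_type": "markdown",
"id": "memprof-note-6",
"metadata": {},
"source": [
"Side by side: the eager `MistralAttention` run reports ~24.4 GB allocated in `forward` versus ~5.4 GB for the flash-attention run, with `repeat_kv` again dominating (~79%). The `attn_weights` matmul adds another ~1.1 GB that is only released when `attn_weights` is set to `None` (the -1077.6 MB on line 301); flash attention never materializes that attention matrix at all. Note the hit counts differ (1568 vs 352): with `do_sample=True` the two runs generated different numbers of tokens, so the totals are indicative rather than an exact apples-to-apples ratio."
]
}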
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}