@huseinzol05
Created May 11, 2024 14:12
example of whisper static cache
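Benchmarks whisper-large-v3 greedy decoding with a torch.compile'd per-token step and a static KV cache against the eager baseline (roughly 398 ms vs 490 ms wall time for the same clip).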
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "eef3ff4e",
"metadata": {},
"outputs": [],
"source": [
"# !pip3.10 install huggingface-hub==0.23\n",
"# !pip3.10 install git+https://github.com/mesolitica/whisper-static-cache\n",
"# !pip3.10 uninstall torch -y; pip3.10 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "bb0ff692",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ['CUDA_VISIBLE_DEVICES'] = '0'"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "f35b7d2a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/husein/.local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"from transformers import WhisperForConditionalGeneration, AutoProcessor, pipeline\n",
"from transformers.cache_utils import WhisperStaticCache\n",
"import torch\n",
"import requests\n",
"from datasets import Audio\n",
"from transformers import AutoProcessor\n",
"from tqdm import tqdm\n",
"\n",
"sr = 16000\n",
"audio = Audio(sampling_rate=sr)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "faaa46ca",
"metadata": {},
"outputs": [],
"source": [
"model_id = \"openai/whisper-large-v3\"\n",
"compute_dtype = torch.bfloat16\n",
"device = \"cuda:0\""
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "70367dfe",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Instantiating WhisperSdpaAttention without passing a `layer_idx` is not recommended and will lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` when creating this class.\n",
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
]
}
],
"source": [
"model = WhisperForConditionalGeneration.from_pretrained(model_id, torch_dtype = compute_dtype) \n",
"processor = AutoProcessor.from_pretrained(model_id)\n",
"_ = model.cuda()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "0a077da5",
"metadata": {},
"outputs": [],
"source": [
"model_normal = WhisperForConditionalGeneration.from_pretrained(model_id, torch_dtype = compute_dtype) \n",
"_ = model_normal.cuda()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "deb50faa",
"metadata": {},
"outputs": [],
"source": [
"model.model.encoder.forward = torch.compile(model.model.encoder.forward, mode='reduce-overhead', fullgraph=True)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "bd05399d",
"metadata": {},
"outputs": [],
"source": [
"def decode_one_tokens(\n",
" model, \n",
" proj_out, \n",
" cur_token, \n",
" past_key_values, \n",
" position_ids, \n",
" cache_position, \n",
" out_encoder,\n",
"):\n",
" \n",
" out_decoder = model(\n",
" cur_token, \n",
" encoder_hidden_states=out_encoder,\n",
" past_key_values = past_key_values,\n",
" position_ids=position_ids,\n",
" use_cache = True,\n",
" return_dict = False,\n",
" cache_position = cache_position\n",
" )\n",
" new_token = torch.argmax(proj_out(out_decoder[0][:,-1:]), dim=-1)\n",
" return new_token"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "3ea5b98d",
"metadata": {},
"outputs": [],
"source": [
"decode_one_tokens = torch.compile(decode_one_tokens, mode=\"reduce-overhead\", fullgraph=True)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "6fa3307b",
"metadata": {},
"outputs": [],
"source": [
"r = requests.get('https://huggingface.co/datasets/huseinzol05/malaya-speech-stt-test-set/resolve/main/test.mp3')\n",
"y = audio.decode_example(audio.encode_example(r.content))['array']\n",
"r = requests.get('https://github.com/mesolitica/malaya-speech/raw/master/speech/singlish/singlish0.wav')\n",
"y2 = audio.decode_example(audio.encode_example(r.content))['array']"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "9b20029a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.\n"
]
}
],
"source": [
"inputs = processor([y], return_tensors = 'pt').to('cuda')\n",
"inputs['input_features'] = inputs['input_features'].type(torch.bfloat16)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "be3af0b9",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.\n"
]
}
],
"source": [
"inputs2 = processor([y2], return_tensors = 'pt').to('cuda')\n",
"inputs2['input_features'] = inputs2['input_features'].type(torch.bfloat16)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "96e732a6",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/husein/.local/lib/python3.10/site-packages/torch/backends/cuda/__init__.py:393: FutureWarning: torch.backends.cuda.sdp_kernel() is deprecated. In the future, this context manager will be removed. Please see, torch.nn.attention.sdpa_kernel() for the new context manager, with updated signature.\n",
" warnings.warn(\n",
"/home/husein/.local/lib/python3.10/site-packages/torch/backends/cuda/__init__.py:393: FutureWarning: torch.backends.cuda.sdp_kernel() is deprecated. In the future, this context manager will be removed. Please see, torch.nn.attention.sdpa_kernel() for the new context manager, with updated signature.\n",
" warnings.warn(\n",
"/home/husein/.local/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py:2176: UserWarning: Unable to hit fast path of CUDAGraphs because of pending, uninvoked backwards. Consider running with torch.no_grad() or using torch.compiler.cudagraph_mark_step_begin() before each model invocation\n",
" warnings.warn(\n"
]
}
],
"source": [
"# warming up\n",
"\n",
"for _ in range(3):\n",
" with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=True):\n",
" out_encoder = model.model.encoder(inputs['input_features'])"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "50cc3d9f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 38.2 ms, sys: 324 µs, total: 38.5 ms\n",
"Wall time: 38 ms\n"
]
}
],
"source": [
"%%time\n",
"\n",
"with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=True):\n",
" out_encoder = model.model.encoder(inputs['input_features'])"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "89edf30b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 44.1 ms, sys: 4.01 ms, total: 48.1 ms\n",
"Wall time: 47.4 ms\n"
]
}
],
"source": [
"%%time\n",
"\n",
"with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=True):\n",
" out_encoder = model_normal.model.encoder(inputs['input_features'])"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "20ff4a3f",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/husein/.local/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py:2176: UserWarning: Unable to hit fast path of CUDAGraphs because of pending, uninvoked backwards. Consider running with torch.no_grad() or using torch.compiler.cudagraph_mark_step_begin() before each model invocation\n",
" warnings.warn(\n"
]
}
],
"source": [
"# warming up\n",
"\n",
"with torch.no_grad():\n",
" language = 'en'\n",
" initial_strings = [\n",
" '<|startoftranscript|>',\n",
" f'<|{language}|>',\n",
" '<|transcribe|>'\n",
" ]\n",
"\n",
" labels = processor.tokenizer(\n",
" ''.join(initial_strings), \n",
" add_special_tokens = False,\n",
" return_tensors = 'pt',\n",
" ).to('cuda')['input_ids']\n",
" out_decoder = model.model.decoder(\n",
" labels, \n",
" encoder_hidden_states=out_encoder[0],\n",
" past_key_values = None,\n",
" use_cache = True\n",
" )\n",
" past_key_values = out_decoder.past_key_values\n",
" proj = model.proj_out(out_decoder.last_hidden_state[:,-1:]).argmax(-1)\n",
" out_encoder = out_encoder[0].clone()\n",
" \n",
" cache = WhisperStaticCache(model.config, compute_dtype, device, past_key_values)\n",
" seq_length = past_key_values[0][0].shape[2]\n",
" cache_position = torch.tensor([seq_length], device=device)\n",
" position_ids = torch.arange(seq_length, seq_length + proj.shape[1], device = device)\n",
" \n",
" for i in range(model.config.max_target_positions - len(initial_strings)):\n",
" proj = decode_one_tokens(\n",
" model.model.decoder, \n",
" model.proj_out, \n",
" proj.clone(), \n",
" cache, \n",
" position_ids,\n",
" cache_position, \n",
" out_encoder\n",
" )\n",
" labels = torch.concat([labels, proj], axis = -1)\n",
" position_ids += 1\n",
" cache_position += 1\n",
"\n",
" if proj == model.config.eos_token_id:\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "f6724dfa",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 50.5 ms, sys: 56 µs, total: 50.5 ms\n",
"Wall time: 49.7 ms\n"
]
}
],
"source": [
"%%time\n",
"\n",
"with torch.no_grad():\n",
" \n",
" with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=True):\n",
" out_encoder = model.model.encoder(inputs['input_features'])\n",
" \n",
" language = 'en'\n",
" initial_strings = [\n",
" '<|startoftranscript|>',\n",
" f'<|{language}|>',\n",
" '<|transcribe|>'\n",
" ]\n",
"\n",
" labels = processor.tokenizer(\n",
" ''.join(initial_strings), \n",
" add_special_tokens = False,\n",
" return_tensors = 'pt',\n",
" ).to('cuda')['input_ids']\n",
" out_decoder = model.model.decoder(\n",
" labels, \n",
" encoder_hidden_states=out_encoder[0],\n",
" past_key_values = None,\n",
" use_cache = True\n",
" )\n",
" past_key_values = out_decoder.past_key_values\n",
" proj = model.proj_out(out_decoder.last_hidden_state[:,-1:]).argmax(-1)\n",
" out_encoder = out_encoder[0].clone()"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "74ded2b3",
"metadata": {},
"outputs": [],
"source": [
"cache.reset(existing_cache = past_key_values)\n",
"seq_length = past_key_values[0][0].shape[2]\n",
"cache_position = torch.tensor([seq_length], device=device)\n",
"position_ids = torch.arange(seq_length, seq_length + proj.shape[1], device = device)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "4014b7fb",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 16%|████████████▉ | 73/445 [00:00<00:01, 186.26it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 396 ms, sys: 3.8 ms, total: 400 ms\n",
"Wall time: 398 ms\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"%%time\n",
"\n",
"with torch.no_grad():\n",
" for i in tqdm(range(model.config.max_target_positions - len(initial_strings))):\n",
" proj = decode_one_tokens(\n",
" model.model.decoder, \n",
" model.proj_out, \n",
" proj.clone(), \n",
" cache, \n",
" position_ids,\n",
" cache_position, \n",
" out_encoder\n",
" )\n",
" labels = torch.concat([labels, proj], axis = -1)\n",
" position_ids += 1\n",
" cache_position += 1\n",
"\n",
" if proj == model.config.eos_token_id:\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "62dc7314",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'<|startoftranscript|><|en|><|transcribe|> Assembly on Aging in Vienna, Australia, which was held in 1982. Based on the evidence, the Ministry of Industry of Malaysia estimated that by 2035, 15% of our population will be from the gold community. For your information, Mr. President and Mr. President, the development of the gold community registration system,<|endoftext|>'"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"processor.tokenizer.decode(labels[0])"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "5d7cc231",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 54.5 ms, sys: 97 µs, total: 54.6 ms\n",
"Wall time: 54.3 ms\n"
]
}
],
"source": [
"%%time\n",
"\n",
"with torch.no_grad():\n",
" \n",
" with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=True):\n",
" out_encoder = model_normal.model.encoder(inputs['input_features'])\n",
" \n",
" language = 'en'\n",
" initial_strings = [\n",
" '<|startoftranscript|>',\n",
" f'<|{language}|>',\n",
" '<|transcribe|>'\n",
" ]\n",
"\n",
" labels = processor.tokenizer(\n",
" ''.join(initial_strings), \n",
" add_special_tokens = False,\n",
" return_tensors = 'pt',\n",
" ).to('cuda')['input_ids']\n",
" out_decoder = model.model.decoder(\n",
" labels, \n",
" encoder_hidden_states=out_encoder[0],\n",
" past_key_values = None,\n",
" use_cache = True\n",
" )\n",
" past_key_values = out_decoder.past_key_values\n",
" proj = model.proj_out(out_decoder.last_hidden_state[:,-1:]).argmax(-1)\n",
" out_encoder = out_encoder[0].clone()"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "202bc451",
"metadata": {},
"outputs": [],
"source": [
"cache_normal = WhisperStaticCache(model.config, compute_dtype, device, past_key_values)\n",
"seq_length = past_key_values[0][0].shape[2]\n",
"cache_position = torch.tensor([seq_length], device=device)\n",
"position_ids = torch.arange(seq_length, seq_length + proj.shape[1], device = device)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "b783a51f",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 16%|████████████▉ | 73/445 [00:00<00:02, 150.20it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 491 ms, sys: 0 ns, total: 491 ms\n",
"Wall time: 490 ms\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"%%time\n",
"\n",
"with torch.no_grad():\n",
" for i in tqdm(range(model_normal.config.max_target_positions - len(initial_strings))):\n",
" out_decoder = model_normal.model.decoder(\n",
" proj, \n",
" encoder_hidden_states=out_encoder,\n",
" past_key_values = cache_normal,\n",
" position_ids=position_ids,\n",
" use_cache = True,\n",
" return_dict = False,\n",
" cache_position = cache_position\n",
" )\n",
" proj = torch.argmax(model_normal.proj_out(out_decoder[0][:,-1:]), dim=-1)\n",
" labels = torch.concat([labels, proj], axis = -1)\n",
" position_ids += 1\n",
" cache_position += 1\n",
"\n",
" if proj == model.config.eos_token_id:\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "19456d76",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'<|startoftranscript|><|en|><|transcribe|> Assembly on Aging in Vienna, Australia, which was held in 1982. Based on the evidence, the Ministry of Industry of Malaysia estimated that by 2035, 15% of our population will be from the gold community. For your information, Mr. President and Mr. President, the development of the gold community registration system,<|endoftext|>'"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"processor.tokenizer.decode(labels[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0742e47d",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "python3.10",
"language": "python",
"name": "python3.10"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
@Omer-ler commented Jun 4, 2024

Very nice.
