Skip to content

Instantly share code, notes, and snippets.

@huseinzol05
Last active May 6, 2024 17:03
Show Gist options
  • Save huseinzol05/4d47ddf026b7c37fcbe02a6afa13205a to your computer and use it in GitHub Desktop.
Save huseinzol05/4d47ddf026b7c37fcbe02a6afa13205a to your computer and use it in GitHub Desktop.
decoder steps using HQQ
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "5c119899",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# Expose only GPU 0 to this process; this cell runs before torch is imported in the next cell.\n",
"os.environ.update(CUDA_VISIBLE_DEVICES='0')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "9b207853",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/husein/.local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
]
}
],
"source": [
"# `time` is used by the timing cells further down; importing it here keeps Restart & Run All working.\n",
"import time\n",
"\n",
"import torch\n",
"from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline\n",
"\n",
"model_id = \"mesolitica/malaysian-whisper-small\"\n",
"compute_dtype = torch.bfloat16  # please don't change this\n",
"device = \"cuda:0\"\n",
"\n",
"# No device is given here; `device` is passed to the HQQ quantisation calls later on.\n",
"model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=compute_dtype)\n",
"processor = AutoProcessor.from_pretrained(model_id)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "3d33aa41",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation=\"flash_attention_2\"` instead.\n",
"You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.\n"
]
}
],
"source": [
"# Reference fp16 model with Flash Attention 2, used as the speed baseline.\n",
"# `attn_implementation` replaces the deprecated `use_flash_attention_2=True` flag.\n",
"model_flash = AutoModelForSpeechSeq2Seq.from_pretrained(\n",
"    model_id, torch_dtype=torch.float16, attn_implementation=\"flash_attention_2\")\n",
"_ = model_flash.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "eee81e05",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 378.09it/s]\n",
"100%|████████████████████████████████████████████████████████████████████████████████| 72/72 [00:00<00:00, 135.91it/s]\n",
"100%|███████████████████████████████████████████████████████████████████████████████| 51/51 [00:00<00:00, 5030.92it/s]\n",
"100%|██████████████████████████████████████████████████████████████████████████████| 120/120 [00:00<00:00, 326.49it/s]\n"
]
},
{
"data": {
"text/plain": [
"WhisperDecoder(\n",
" (embed_tokens): Embedding(51865, 768, padding_idx=50257)\n",
" (embed_positions): WhisperPositionalEmbedding(448, 768)\n",
" (layers): ModuleList(\n",
" (0-11): 12 x WhisperDecoderLayer(\n",
" (self_attn): WhisperSdpaAttention(\n",
" (k_proj): HQQLinear()\n",
" (v_proj): HQQLinear()\n",
" (q_proj): HQQLinear()\n",
" (out_proj): HQQLinear()\n",
" )\n",
" (activation_fn): GELUActivation()\n",
" (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (encoder_attn): WhisperSdpaAttention(\n",
" (k_proj): Linear(in_features=768, out_features=768, bias=False)\n",
" (v_proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (q_proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
" )\n",
" (encoder_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): HQQLinear()\n",
" (fc2): HQQLinear()\n",
" (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" )\n",
" (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
")"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from hqq.models.hf.base import AutoHQQHFModel\n",
"# Explicit imports instead of `import *`: these are the only names this notebook uses from the module.\n",
"from hqq.core.quantize import BaseQuantizeConfig, HQQBackend, HQQLinear\n",
"\n",
"# Please keep nbits=4 and axis=1\n",
"quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_scale=False, quant_zero=False, axis=1)\n",
"HQQLinear.set_backend(HQQBackend.PYTORCH)\n",
"\n",
"# Swap the linear layers of encoder and decoder for HQQLinear (the decoder repr is displayed below).\n",
"AutoHQQHFModel.quantize_model(model.model.encoder, quant_config=quant_config, compute_dtype=compute_dtype, device=device)\n",
"AutoHQQHFModel.quantize_model(model.model.decoder, quant_config=quant_config, compute_dtype=compute_dtype, device=device)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "33d23a7a",
"metadata": {},
"outputs": [],
"source": [
"# Overwrite hqq's private module-level `_QUANT_LAYERS` list so it contains both nn.Linear and HQQLinear.\n",
"# NOTE(review): presumably required so the tagging / inference-patching helpers below recognise the\n",
"# already-quantised HQQLinear layers - confirm against the installed hqq version.\n",
"import hqq.models.base as hqq_base\n",
"hqq_base._QUANT_LAYERS = [torch.nn.Linear, HQQLinear]\n",
"\n",
"from hqq.utils.patching import prepare_for_inference\n",
"\n",
"# Encoder: tag its linear layers, then patch it using prepare_for_inference's default backend.\n",
"AutoHQQHFModel.set_auto_linear_tags(model.model.encoder)\n",
"prepare_for_inference(model.model.encoder)\n",
"\n",
"# Decoder: same steps, but with the torchao_int4 backend.\n",
"# NOTE(review): this int4 backend is presumably why nbits=4 / axis=1 must be kept above.\n",
"AutoHQQHFModel.set_auto_linear_tags(model.model.decoder)\n",
"prepare_for_inference(model.model.decoder, backend=\"torchao_int4\")"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "b9ebb9e7",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
"It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.\n"
]
}
],
"source": [
"import requests\n",
"from datasets import Audio\n",
"from transformers import AutoProcessor\n",
"\n",
"sr = 16000\n",
"audio = Audio(sampling_rate=sr)\n",
"\n",
"# Download the test clip; fail loudly rather than trying to decode an HTTP error body.\n",
"# Named `response` (not `r`) because `r` is reused later for the generate() output.\n",
"response = requests.get(\n",
"    'https://huggingface.co/datasets/huseinzol05/malaya-speech-stt-test-set/resolve/main/test.mp3',\n",
"    timeout=60)\n",
"response.raise_for_status()\n",
"\n",
"# Decode the downloaded bytes through datasets.Audio configured at `sr` Hz.\n",
"y = audio.decode_example(audio.encode_example(response.content))['array']\n",
"\n",
"# `processor` was already loaded from the same checkpoint (`model_id`) in the imports cell, so it is reused.\n",
"# Passing `sampling_rate` lets the feature extractor check the rate instead of emitting a warning.\n",
"inputs = processor([y], sampling_rate=sr, return_tensors='pt').to(device)\n",
"inputs['input_features'] = inputs['input_features'].type(compute_dtype)  # bf16 for the HQQ model"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "0d827bcc",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 44.6 ms, sys: 2.41 ms, total: 47 ms\n",
"Wall time: 5.62 ms\n"
]
}
],
"source": [
"%%time\n",
"\n",
"with torch.no_grad():\n",
"    out_encoder = model.model.encoder(inputs['input_features'])\n",
"# CUDA kernels are launched asynchronously; wait for them so the wall time covers the GPU work.\n",
"torch.cuda.synchronize()"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "4129c4a6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[50258, 50282, 50359]], device='cuda:0')"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Decoder prompt: start-of-transcript, Malay language tag, transcribe task.\n",
"prompt = '<|startoftranscript|><|ms|><|transcribe|>'\n",
"encoded_prompt = processor.tokenizer(prompt, add_special_tokens=False, return_tensors='pt')\n",
"labels = encoded_prompt['input_ids'].to('cuda')\n",
"labels"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "12a2e190",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.48621106147766113"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Greedy decoding, one token per step, through the HQQ-quantised decoder.\n",
"# No KV cache is used: every step re-runs the decoder over the whole prefix.\n",
"# NOTE: `labels` is rebound to the growing sequence, so re-run the prompt cell before re-running this one.\n",
"\n",
"# Cap the number of steps by the decoder's positional-embedding table (448 rows in the repr above);\n",
"# a fixed range(1024) would index past it if EOS were never produced.\n",
"max_steps = model.config.max_target_positions - labels.shape[1]\n",
"\n",
"torch.cuda.synchronize()\n",
"before = time.time()\n",
"\n",
"with torch.no_grad():\n",
"    for _ in range(max_steps):\n",
"        out_decoder = model.model.decoder(labels, encoder_hidden_states=out_encoder[0])\n",
"        proj = model.proj_out(out_decoder.last_hidden_state[:, -1:]).argmax(-1)\n",
"        labels = torch.concat([labels, proj], dim=1)\n",
"        if proj.item() == model.config.eos_token_id:  # batch size 1\n",
"            break\n",
"\n",
"# Wait for queued CUDA kernels so the measured time covers the actual GPU work.\n",
"torch.cuda.synchronize()\n",
"after_hqq = time.time() - before\n",
"after_hqq"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "1ded45f4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'<|startoftranscript|><|ms|><|transcribe|> Assembly on Aging di Vienna, Australia yang telah diadakan pada tahun 1982 dan berasaskan unjuran tersebut maka Jabatan Perangkaan Malaysia menganggarkan menjelang tahun 2035 sejumlah 15% penduduk kita adalah daripada kalangan warga emas. Untuk makluman Tuan Yang Pertua dan juga Aliam Bohmat, pembangunan sistem pendaftaran warga emas ataupun kita sebutkan IWEN adalah usaha kerajaan ke arah merealisasikan objektif yang telah<|endoftext|>'"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"transcript = processor.tokenizer.decode(labels[0])  # special tokens are kept in the text\n",
"transcript"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "8343bbac",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.\n"
]
}
],
"source": [
"# fp16 features for the flash-attention model, which was loaded with torch_dtype=torch.float16.\n",
"# Passing `sampling_rate` lets the feature extractor check the rate instead of emitting a warning.\n",
"inputs = processor([y], sampling_rate=sr, return_tensors='pt').to(device)\n",
"inputs['input_features'] = inputs['input_features'].type(torch.float16)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "e35c5c21",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 302 ms, sys: 0 ns, total: 302 ms\n",
"Wall time: 301 ms\n"
]
},
{
"data": {
"text/plain": [
"0.30090904235839844"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"\n",
"# NOTE: unlike the HQQ loop, this timing includes the encoder pass (input_features are passed to generate),\n",
"# and generate() presumably reuses past key/values (use_cache defaults to True) - not an apples-to-apples comparison.\n",
"# max_length is capped at the decoder's positional-embedding size instead of 1024.\n",
"torch.cuda.synchronize()\n",
"before = time.time()\n",
"r = model_flash.generate(\n",
"    inputs['input_features'], language='ms', return_timestamps=True,\n",
"    max_length=model_flash.config.max_target_positions)\n",
"torch.cuda.synchronize()\n",
"after_flash = time.time() - before\n",
"after_flash"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "8711fc0f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"282.4773869665257"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"flash_seq_len = len(r[0])  # length of the (only) generated id sequence\n",
"flash_seq_len / after_flash  # tokens per second for flash (timing includes the encoder pass)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "0a58cf39",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"296.1676757463406"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hqq_seq_len = len(labels[0])  # prompt + greedily decoded tokens\n",
"hqq_seq_len / after_hqq  # tokens per second for HQQ (decoder loop only; encoder was timed separately)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "02ed610c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "python3.10",
"language": "python",
"name": "python3.10"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment