Last active May 6, 2024
decoder steps using HQQ
"import os\n",
"os.environ['CUDA_VISIBLE_DEVICES'] = '0'"
"outputs": [
"source": [
"import torch\n",
"from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline\n",
"model_id = \"mesolitica/malaysian-whisper-small\"\n",
"compute_dtype = torch.bfloat16 # please don't change this\n",
"device = \"cuda:0\"\n",
"model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=compute_dtype) \n",
"processor = AutoProcessor.from_pretrained(model_id)"
"outputs": [
"The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation=\"flash_attention_2\"` instead.\n",
"You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `'cuda')`.\n"
"source": [
"model_flash = AutoModelForSpeechSeq2Seq.from_pretrained(\n",
" model_id, torch_dtype=torch.float16, use_flash_attention_2 = True) \n",
"_ = model_flash.cuda()"
"outputs": [
" (embed_tokens): Embedding(51865, 768, padding_idx=50257)\n",
" (embed_positions): WhisperPositionalEmbedding(448, 768)\n",
" (layers): ModuleList(\n",
" (0-11): 12 x WhisperDecoderLayer(\n",
" (self_attn): WhisperSdpaAttention(\n",
" (k_proj): HQQLinear()\n",
" (v_proj): HQQLinear()\n",
" (q_proj): HQQLinear()\n",
" (out_proj): HQQLinear()\n",
" )\n",
" (activation_fn): GELUActivation()\n",
" (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (encoder_attn): WhisperSdpaAttention(\n",
" (k_proj): Linear(in_features=768, out_features=768, bias=False)\n",
" (v_proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (q_proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
" )\n",
" (encoder_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (fc1): HQQLinear()\n",
" (fc2): HQQLinear()\n",
" (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" )\n",
" (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
"source": [
"from hqq.models.hf.base import AutoHQQHFModel\n",
"from hqq.core.quantize import *\n",
"# Please keep nbits=4 and axis=1\n",
"quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_scale=False, quant_zero=False, axis=1) \n",
"AutoHQQHFModel.quantize_model(model.model.encoder, quant_config=quant_config, compute_dtype=compute_dtype, device=device)\n",
"AutoHQQHFModel.quantize_model(model.model.decoder, quant_config=quant_config, compute_dtype=compute_dtype, device=device)"
"source": [
"import hqq.models.base as hqq_base\n",
"hqq_base._QUANT_LAYERS = [torch.nn.Linear, HQQLinear]\n",
"from hqq.utils.patching import prepare_for_inference\n",
"prepare_for_inference(model.model.decoder, backend=\"torchao_int4\")"
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
"It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.\n"
"source": [
"import requests\n",
"from datasets import Audio\n",
"from transformers import AutoProcessor\n",
"sr = 16000\n",
"audio = Audio(sampling_rate=sr)\n",
"r = requests.get('')\n",
"processor = AutoProcessor.from_pretrained('mesolitica/malaysian-whisper-small')\n",
"y = audio.decode_example(audio.encode_example(r.content))['array']\n",
"inputs = processor([y], return_tensors = 'pt').to('cuda')\n",
"inputs['input_features'] = inputs['input_features'].type(torch.bfloat16)"
"outputs": [
CPU times: user 44.6 ms, sys: 2.41 ms, total: 47 ms
Wall time: 5.62 ms
"Wall time: 5.62 ms\n"
"source": [
"with torch.no_grad():\n",
" out_encoder = model.model.encoder(inputs['input_features'])"
tensor([[50258, 50282, 50359]], device='cuda:0')
"source": [
"labels = processor.tokenizer(\n",
" '<|startoftranscript|><|ms|><|transcribe|>', \n",
" add_special_tokens = False, \n",
" return_tensors = 'pt'\n",
"source": [
"before = time.time()\n",
"for _ in range(1024):\n",
" with torch.no_grad():\n",
" out_decoder = model.model.decoder(labels, encoder_hidden_states=out_encoder[0])\n",
" proj = model.proj_out(out_decoder.last_hidden_state[:,-1:]).argmax(-1)\n",
" labels = torch.concat([labels, proj], dim = 1)\n",
" if proj == model.config.eos_token_id:\n",
" break\n",
" \n",
"after_hqq = time.time() - before\n",
'<|startoftranscript|><|ms|><|transcribe|> Assembly on Aging di Vienna, Australia yang telah diadakan pada tahun 1982 dan berasaskan unjuran tersebut maka Jabatan Perangkaan Malaysia menganggarkan menjelang tahun 2035 sejumlah 15% penduduk kita adalah daripada kalangan warga emas. Untuk makluman Tuan Yang Pertua dan juga Aliam Bohmat, pembangunan sistem pendaftaran warga emas ataupun kita sebutkan IWEN adalah usaha kerajaan ke arah merealisasikan objektif yang telah<|endoftext|>'
"source": [
"source": [
"inputs = processor([y], return_tensors = 'pt').to('cuda')\n",
"inputs['input_features'] = inputs['input_features'].type(torch.float16)"
CPU times: user 302 ms, sys: 0 ns, total: 302 ms
Wall time: 301 ms
"Wall time: 301 ms\n"
"source": [
"before = time.time()\n",
"r = model_flash.generate(inputs['input_features'], language='ms', return_timestamps=True, max_length = 1024)\n",
"after_flash = time.time() - before\n",
"source": [
len(r[0]) / after_flash # Flash Attention 2 baseline: length of the first returned sequence divided by the generate() wall time in seconds
"source": [
len(labels[0]) / after_hqq # HQQ decode loop: length of `labels` (prompt tokens included) divided by the loop wall time in seconds
