Skip to content

Instantly share code, notes, and snippets.

@attila-dusnoki-htec
Last active November 28, 2023 16:11
Show Gist options
  • Save attila-dusnoki-htec/1bf961f975df4258ada9478fc4aa707f to your computer and use it in GitHub Desktop.
Save attila-dusnoki-htec/1bf961f975df4258ada9478fc4aa707f to your computer and use it in GitHub Desktop.
Whisper speech-to-text with MIGraphX (mgx) — work in progress
from transformers import WhisperProcessor, WhisperTokenizer
from datasets import load_dataset
import migraphx as mgx
import os
import numpy as np
from tqdm.auto import tqdm
# Load the tokenizer and processor that belong to the openai/whisper-tiny.en
# checkpoint (the ONNX/MIGraphX models themselves are loaded separately).
# NOTE(review): `tokenizer` is not referenced elsewhere in this script as
# shown; detokenization goes through `processor.batch_decode`.
_checkpoint_id = "openai/whisper-tiny.en"
tokenizer = WhisperTokenizer.from_pretrained(_checkpoint_id)
processor = WhisperProcessor.from_pretrained(_checkpoint_id)
# Compiled MIGraphX programs are cached as .mxr files next to their ONNX
# sources: the first run parses + compiles for the GPU and saves the result,
# later runs load the cached program directly.
_MODEL_DIR = "models/whisper_with_attn_mask_onnx"

# Fixed input dimensions each ONNX graph is compiled with.
encoder_model_param_shapes = {"input_features": [1, 80, 3000]}
decoder_model_param_shapes = {
    "decoder_input_ids": [1, 448],
    "decoder_attention_mask": [1, 448],
    "encoder_hidden_states": [1, 1500, 384],
}


def _load_or_compile(name, input_dims):
    """Return the GPU-compiled MIGraphX program for ``<_MODEL_DIR>/<name>``.

    Args:
        name: Base file name shared by ``<name>.onnx`` and ``<name>.mxr``.
        input_dims: Mapping of ONNX input name -> fixed dimensions, passed to
            ``mgx.parse_onnx`` as ``map_input_dims``. Only used when the
            cached ``.mxr`` file does not exist yet.

    Returns:
        The loaded (cached) or freshly compiled program.
    """
    mxr_path = f"{_MODEL_DIR}/{name}.mxr"
    if os.path.isfile(mxr_path):
        # NOTE(review): the cached program is not checked against the .onnx
        # file or `input_dims`; delete the .mxr after changing either.
        return mgx.load(mxr_path, format="msgpack")
    program = mgx.parse_onnx(f"{_MODEL_DIR}/{name}.onnx",
                             map_input_dims=input_dims)
    program.compile(mgx.get_target("gpu"))
    mgx.save(program, mxr_path, format="msgpack")
    return program


encoder_model = _load_or_compile("encoder_model", encoder_model_param_shapes)
decoder_model = _load_or_compile("decoder_model", decoder_model_param_shapes)
# Fetch the small LibriSpeech dummy split and turn the first clip into the
# processor's input features (as a PyTorch tensor).
ds = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy",
    "clean",
    split="validation",
)
sample = ds[0]["audio"]
audio_array, audio_rate = sample["array"], sample["sampling_rate"]
input_features = processor(
    audio_array,
    sampling_rate=audio_rate,
    return_tensors="pt",
).input_features
print(f"input_features=... shape={input_features.shape}")
# Special-token ids used to build the decoder prompt.
decoder_start_token_id = 50257  # <|startoftranscript|>
pad_token_id = 50256  # "<|endoftext|>" — also fills the unused tail of the buffer
task = 50358  # "<|transcribe|>"
notimestamps = 50362  # <|notimestamps|>
language = 50258  # <|en|>
# Fixed decoder length; must equal the [1, 448] dims the decoder program is
# compiled with for decoder_input_ids / decoder_attention_mask.
max_length = 448
sot = [decoder_start_token_id, language, task, notimestamps]

# Token buffer: the prompt followed by padding up to max_length. Built as
# int64 directly (the dtype the decoder inputs are cast to before each run)
# instead of relying on NumPy's platform-dependent default integer type.
decoder_input_ids = np.full((1, max_length), pad_token_id, dtype=np.int64)
decoder_input_ids[0, :len(sot)] = sot
# 0 masked | 1 un-masked — initially only the prompt positions are visible.
decoder_attention_mask = np.zeros((1, max_length), dtype=np.int64)
decoder_attention_mask[0, :len(sot)] = 1
# Run the encoder once on the float32 features; its first output is fed to
# every decoder step as "encoder_hidden_states".
_encoder_feed = {
    "input_features": input_features.detach().cpu().numpy().astype(np.float32),
}
result = encoder_model.run(_encoder_feed)
states = np.array(result[0])
# Greedy decoding against a fixed-size [1, max_length] decoder program. Every
# step re-runs the whole buffer, and the logits at position t choose the token
# written to position t + 1. The loop therefore stops at max_length - 2, so the
# last write lands on the final slot (max_length - 1) and never past the end.
encoder_hidden_states = states.astype(np.float32)  # constant across steps; cast once
token_len = max_length  # used when no end-of-text token is produced
for t in tqdm(range(len(sot) - 1, max_length - 1)):
    # Unmask position t so the decoder can attend to the newest token.
    decoder_attention_mask[0][t] = 1
    result = np.array(decoder_model.run(
        {"decoder_input_ids":
         decoder_input_ids.astype(np.int64),
         "decoder_attention_mask":
         decoder_attention_mask.astype(np.int64),
         "encoder_hidden_states": encoder_hidden_states})[0])
    # result.shape = [1,max_length,51864]
    new_token = np.argmax(result[0][t])
    decoder_input_ids[0][t + 1] = new_token
    if new_token == pad_token_id:
        # End of text: the sequence is positions 0..t plus the end-of-text
        # token just written at t + 1, i.e. t + 2 tokens in total.
        token_len = t + 2
        break
# Cut the buffer down to the generated length, then detokenize it with the
# special tokens kept in the output.
decoder_input_ids = decoder_input_ids[:, :token_len]
print(f"{decoder_input_ids.shape}")
transcription = processor.batch_decode(
    decoder_input_ids,
    skip_special_tokens=False,
)
# Expected output for the first dummy sample:
# # ['<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.<|endoftext|>']
print(f"transcription={transcription}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment