Skip to content

Instantly share code, notes, and snippets.

@moomou
Created January 6, 2024 04:48
Show Gist options
  • Save moomou/7df8345d79a0063d67d1fa2b4cf55db8 to your computer and use it in GitHub Desktop.
Save moomou/7df8345d79a0063d67d1fa2b4cf55db8 to your computer and use it in GitHub Desktop.
Sampling with Audio
import torch
import whisper
def prompt_template_fn(
    prompt="Describe the sound of the given file",
    system_message="You are a helpful AI who follows instruction carefully",
):
    """Build the opening part of a chat prompt: system turn plus an open user turn.

    The result uses ``<|im_start|>`` / ``<|im_end|>`` markers. The system turn
    is closed, but the user turn is intentionally left open (no trailing
    ``<|im_end|>``) so that more content can be placed after the prompt text
    before the turn is closed; ``end_template()`` in this file supplies the
    closing marker.

    Args:
        prompt: Text of the user instruction.
        system_message: Text of the system turn. Defaults to the message that
            was previously hard-coded, so existing callers are unaffected.

    Returns:
        The prompt prefix as a single string.
    """
    return (
        "<|im_start|>system\n"
        f"{system_message}<|im_end|>\n"
        "<|im_start|>user\n"
        f"{prompt}"
    )
def end_template():
    """Return the text that closes the user turn and opens the assistant turn.

    The string starts with a newline, then ``<|im_end|>``, then an
    ``<|im_start|>assistant`` header followed by a trailing newline.
    """
    closing_user_turn = "\n<|im_end|>\n"
    opening_assistant_turn = "<|im_start|>assistant\n"
    return closing_user_turn + opening_assistant_turn
def load_audio_mels(file):
    """Load an audio file and return its log-mel spectrogram with a batch axis.

    The waveform is read with ``whisper.load_audio``, fitted to whisper's
    default length with ``whisper.pad_or_trim``, converted to a 128-bin
    log-mel spectrogram, and given a leading dimension of size 1.

    Args:
        file: Path of the audio file, passed straight to ``whisper.load_audio``.

    Returns:
        The spectrogram tensor with an extra leading axis (``unsqueeze(0)``).
    """
    waveform = whisper.pad_or_trim(whisper.load_audio(file))
    mels = whisper.log_mel_spectrogram(waveform, n_mels=128)
    # Add the batch dimension expected by the caller.
    return mels.unsqueeze(0)
def text_2_ids_and_attention_mask(tokenizer, input_txt, truncate=False):
    """Tokenize ``input_txt`` and return ``(input_ids, attention_mask)``.

    Args:
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors="pt")``
            whose result exposes ``input_ids`` and ``attention_mask``.
        input_txt: Text to tokenize.
        truncate: When true, drop the first position along the second axis of
            both outputs (presumably the leading special token the tokenizer
            prepends -- confirm against the tokenizer in use).

    Returns:
        A 2-tuple ``(input_ids, attention_mask)``.
    """
    encoded = tokenizer(input_txt, return_tensors="pt")
    ids, mask = encoded.input_ids, encoded.attention_mask
    if not truncate:
        return ids, mask
    return ids[:, 1:], mask[:, 1:]
@torch.no_grad()
def sample_with_audio(model, tokenizer, prompt, audio_file, device="cuda:0", iteration=50):
    """Sample ``iteration`` tokens from ``model.llm`` conditioned on text and audio.

    The LLM input is built along the sequence axis (dim 1) as::

        [prompt embeddings | audio encoder output | end-template embeddings
         | embeddings of the tokens sampled so far]

    The full sequence is re-run on every step (no key/value cache is used) and
    the next token is drawn with ``torch.multinomial`` from the softmax of the
    last position's logits. Sampling always runs for exactly ``iteration``
    steps; there is no early stop on an end-of-sequence token.

    Args:
        model: Object exposing ``audio_encoder`` (called on the mel tensor),
            ``llm`` (called with ``inputs_embeds`` and ``attention_mask``,
            returning an object with ``.logits``) and
            ``llm.model.embed_tokens``.
        tokenizer: Tokenizer passed to ``text_2_ids_and_attention_mask``.
        prompt: Prompt text placed before the audio embeddings.
        audio_file: Audio file path, loaded through ``load_audio_mels``.
        device: Device that the ids, masks and mel tensor are moved to.
        iteration: Number of tokens to sample. Values ``<= 0`` sample nothing.

    Returns:
        Token ids concatenated on the last axis: prompt ids, end-template ids,
        then the sampled ids. The audio positions have no token ids and are
        therefore not represented in the result.
    """
    # truncate=True drops the first token of the end template (presumably the
    # special token the tokenizer prepends; it must not appear mid-sequence).
    end_prompt_ids, end_prompt_attention_mask = text_2_ids_and_attention_mask(
        tokenizer,
        end_template(),
        truncate=True,
    )
    prompt_ids, prompt_attention_mask = text_2_ids_and_attention_mask(
        tokenizer,
        prompt,
    )

    prompt_ids = prompt_ids.to(device)
    prompt_attention_mask = prompt_attention_mask.to(device)
    end_prompt_ids = end_prompt_ids.to(device)
    end_prompt_attention_mask = end_prompt_attention_mask.to(device)

    # Nothing to sample: return the text ids alone instead of failing on a
    # concat with a missing sampled tensor.
    if iteration <= 0:
        return torch.concat((prompt_ids, end_prompt_ids), dim=-1)

    audio_mels = load_audio_mels(audio_file).to(device).half()

    sampled_ids = None
    # NOTE: autocast is pinned to CUDA / float16 regardless of ``device``.
    with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
        embed_tokens = model.llm.model.embed_tokens

        # Everything below up to the loop is invariant across sampling steps,
        # so it is computed once.
        audio_embeds = model.audio_encoder(audio_mels)
        bs, audio_seq = audio_embeds.shape[:2]

        prompt_embeds = embed_tokens(prompt_ids)
        end_prompt_embeds = embed_tokens(end_prompt_ids)

        context_embeds = torch.concat(
            (
                prompt_embeds,
                audio_embeds.to(prompt_embeds.dtype),
                end_prompt_embeds,
            ),
            dim=1,
        )
        # Every audio position is attended to; keep the tokenizer's mask dtype
        # so the concatenated mask is not silently promoted to float.
        audio_attention_mask = torch.ones(
            bs,
            audio_seq,
            dtype=prompt_attention_mask.dtype,
            device=audio_embeds.device,
        )
        context_mask = torch.concat(
            (
                prompt_attention_mask,
                audio_attention_mask,
                end_prompt_attention_mask,
            ),
            dim=1,
        )

        for _ in range(iteration):
            if sampled_ids is None:
                inputs_embeds = context_embeds
                attention_mask = context_mask
            else:
                inputs_embeds = torch.concat(
                    (context_embeds, embed_tokens(sampled_ids)),
                    dim=1,
                )
                # Rebuild the mask from the fixed context each step so its
                # length always equals that of ``inputs_embeds``; appending to
                # a running list would count earlier samples more than once.
                attention_mask = torch.concat(
                    (
                        context_mask,
                        torch.ones_like(sampled_ids, dtype=context_mask.dtype),
                    ),
                    dim=1,
                )

            mout = model.llm(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
            )
            next_token_probs = mout.logits[:, -1, :].softmax(dim=-1)
            sampled = torch.multinomial(next_token_probs, 1)

            if sampled_ids is None:
                sampled_ids = sampled
            else:
                sampled_ids = torch.cat((sampled_ids, sampled), dim=-1)

    return torch.concat(
        (
            prompt_ids,
            end_prompt_ids,
            sampled_ids,
        ),
        dim=-1,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment