@eschmidbauer
Created July 19, 2024 19:35
import shutil

import torch
from datasets import load_dataset
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import (WhisperForConditionalGeneration, WhisperModel,
                          WhisperProcessor)
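
# Distill a fine-tuned Whisper large-v3 checkpoint to tolerate a reduced,
# audio-length-dependent encoder context: a trainable copy of the model is
# optimized so that its decoder hidden states (computed from a shortened encoder
# output) match those of a frozen copy that sees the full 1500-position output.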
MODEL = "large-v3"

# Student (trainable) and teacher (frozen) copies of the same fine-tuned checkpoint.
model_train = WhisperModel.from_pretrained("model_train-large-v3-finetuned-2").cuda().train()  # noqa
model_base = WhisperModel.from_pretrained("model_train-large-v3-finetuned-2").cuda().eval()  # noqa

ds = load_dataset("google/fleurs", "en_us", split="train")
processor = WhisperProcessor.from_pretrained("model_train-large-v3-finetuned-2")  # noqa

def get_sample(example):
    """Extract log-mel features, duration in seconds, and target token ids for one example."""
    waveform = example["audio"]["array"]
    sampling_rate = example["audio"]["sampling_rate"]
    input_features = processor(waveform, sampling_rate=sampling_rate, return_tensors="pt").input_features  # noqa
    return {
        "length": len(waveform) / sampling_rate,
        "input_features": input_features,
        "input_ids": processor.tokenizer.encode(example["raw_transcription"].lower()),
    }

if not (".en" in MODEL):
print(processor.get_decoder_prompt_ids(language="english", task="transcribe")) # noqa
[processor.tokenizer.decode(i) for i in get_sample(ds[1])["input_ids"]]
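
# compute_partially_encoder() mirrors the Hugging Face WhisperEncoder forward pass,
# but pads/truncates the mel features to 2 * n_audio_ctx frames and slices the
# positional-embedding table to n_audio_ctx rows, so the encoder output has shape
# (batch, n_audio_ctx, d_model) instead of the fixed (batch, 1500, d_model).
# For n_audio_ctx == 1500 it simply delegates to the unmodified encoder.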
def compute_partially_encoder(model, data, n_audio_ctx):
    # Pad or truncate the mel features to exactly 2 * n_audio_ctx frames
    # (the two conv layers halve the time axis, leaving n_audio_ctx positions).
    diffy = 2 * n_audio_ctx - data.shape[2]
    if diffy > 0:
        data = nn.functional.pad(data, [0, diffy, 0, 0, 0, 0], "constant", 0.0)
    elif diffy < 0:
        data = data[:, :, :diffy]

    # Full context: fall back to the stock encoder forward pass.
    if n_audio_ctx == 1500:
        return model.encoder(data).last_hidden_state

    input_embeds = nn.functional.gelu(model.encoder.conv1(data))
    input_embeds = nn.functional.gelu(model.encoder.conv2(input_embeds))
    input_embeds = input_embeds.permute(0, 2, 1)
    # Use only the first n_audio_ctx positional embeddings.
    embed_pos = model.encoder.embed_positions.weight[:n_audio_ctx]
    hidden_states = input_embeds + embed_pos
    hidden_states = nn.functional.dropout(hidden_states, p=model.encoder.dropout, training=model.encoder.training)  # noqa

    for idx, encoder_layer in enumerate(model.encoder.layers):
        # LayerDrop: randomly skip whole encoder layers during training.
        to_drop = False
        if model.encoder.training:
            dropout_probability = torch.rand([])
            if dropout_probability < model.encoder.layerdrop:
                to_drop = True

        if to_drop:
            layer_outputs = (None, None)
        else:
            if model.encoder.gradient_checkpointing and model.encoder.training:
                layer_outputs = model.encoder._gradient_checkpointing_func(encoder_layer.__call__, hidden_states, None, None, False)  # noqa
            else:
                layer_outputs = encoder_layer(hidden_states, None, layer_head_mask=None, output_attentions=False)  # noqa
            hidden_states = layer_outputs[0]

    hidden_states = model.encoder.layer_norm(hidden_states)
    return hidden_states
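
# Distillation step: the student encoder runs with a context proportional to the
# clip length (1500 positions cover 30 s, i.e. about 50 positions per second),
# jittered by up to ±min(64, n_ctx // 3) positions. The MSE between the student's
# and the frozen teacher's decoder hidden states (all layers, concatenated along
# the batch dimension) is the training signal.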
def compute_hidden_state_loss(model_train, model_base, optimizer, criterion, example):
    optimizer.zero_grad()

    # Target encoder context for this clip, plus random jitter.
    n_ctx = int(round((1500.0 / 30.0) * example["length"]))
    extra_ctx = torch.randint(-min(64, n_ctx // 3), min(64, n_ctx // 3), (1,)).item()  # noqa
    n_ctx += extra_ctx

    input_features = example["input_features"].cuda()
    input_ids = torch.tensor([example["input_ids"]], dtype=torch.long).cuda()

    # Student: reduced encoder context, gradients enabled.
    encoder_hidden_states_partial = compute_partially_encoder(model_train, input_features, n_ctx)  # noqa
    output_partial = model_train.decoder(input_ids=input_ids, encoder_hidden_states=encoder_hidden_states_partial, output_hidden_states=True)  # noqa

    # Teacher: full 1500-position context, no gradients.
    with torch.no_grad():
        encoder_hidden_states_full = compute_partially_encoder(model_base, input_features, 1500)  # noqa
        output_full = model_base.decoder(input_ids=input_ids, encoder_hidden_states=encoder_hidden_states_full, output_hidden_states=True)  # noqa

    loss = criterion(
        # output_partial.hidden_states[-1],
        # output_full.hidden_states[-1]
        torch.cat(output_partial.hidden_states, 0),
        torch.cat(output_full.hidden_states, 0)
    )
    loss.backward()
    optimizer.step()
    return loss

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model_train.parameters(), lr=1e-6)
writer = SummaryWriter()
writer.add_text("name", f"{MODEL} v3")
num_length = 0
step = 0
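
# Training loop: up to 1024 shuffled passes over FLEURS. Clips longer than 29 s are
# skipped (Whisper's encoder covers at most 30 s of audio). Loss, cumulative seconds
# of audio seen, and the epoch index are logged to TensorBoard.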
for epoch in range(1024):
    pbar = tqdm(ds.shuffle(seed=epoch))
    for example in pbar:
        example = get_sample(example)
        if example["length"] > 29.0:
            continue
        loss = compute_hidden_state_loss(model_train, model_base, optimizer, criterion, example)  # noqa
        step += 1
        num_length += example["length"]
        writer.add_scalar("loss/train", loss.item(), step)
        writer.add_scalar("length/train", num_length, step)
        writer.add_scalar("epoch/train", epoch, step)
        pbar.set_description(f"Epoch {epoch}, Loss: {loss.item()}")
# Evaluation: transcribe LibriSpeech samples with both the frozen base model and the
# distilled model, and print the two transcriptions against the ground-truth text.
ds_eval = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")  # noqa
model = WhisperForConditionalGeneration.from_pretrained("model_train-large-v3-finetuned-2").eval().cuda()  # noqa

for i in range(64):
    audio_sample = ds_eval[i]["audio"]
    waveform = audio_sample["array"]
    sampling_rate = audio_sample["sampling_rate"]
    input_features = processor(waveform, sampling_rate=sampling_rate, return_tensors="pt").input_features.cuda()  # noqa

    # Swap the underlying WhisperModel to generate with either checkpoint.
    model.model = model_base.eval().cuda()
    predicted_ids_base = model.generate(input_features)
    model.model = model_train.eval().cuda()
    predicted_ids_train = model.generate(input_features)

    transcription = processor.batch_decode([predicted_ids_base[0], predicted_ids_train[0]], skip_special_tokens=True)  # noqa
    print(f"\n\nGrndTr: {ds_eval[i]['text'].lower()}\nModelB: {transcription[0]}\nModelT: {transcription[1]}")  # noqa
# Export: swap the distilled weights into a full seq2seq model, save it, and zip the
# resulting directory.
model = WhisperForConditionalGeneration.from_pretrained("model_train-large-v3-finetuned-2").eval().cpu()  # noqa
model.model = model_train.eval().cpu()
model.save_pretrained(f"model_train-{MODEL}3")
shutil.make_archive(f"model_train-{MODEL}3", 'zip', f"model_train-{MODEL}3")
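
# The saved directory (name derived from MODEL, e.g. "model_train-large-v33") should
# load back with the standard Hugging Face API, for example:
#   reloaded = WhisperForConditionalGeneration.from_pretrained("model_train-large-v33")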